package server

import (
	"errors"
	"fmt"
	lru "github.com/hashicorp/golang-lru"
	"github.com/odysseus/nodemanager/utils"
	odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1"
	omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v1"
	"github.com/redis/go-redis/v9"
	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/sha3"
	"math/big"
	"sync"
	"time"
)

// ErrWorkerExist is returned by AddNewWorker when a worker with the same uuid
// is already registered.
var ErrWorkerExist = errors.New("worker exist")

// ErrHeartBeatExpired is returned by manageWorker when the last recorded
// heartbeat of a worker is older than the worker check interval.
var ErrHeartBeatExpired = errors.New("worker heartbeat expired")

// dispatchTask pairs a task that should be pushed to a worker with a channel
// on which the outcome of that push is reported.
type dispatchTask struct {
	// task is the content that manageWorker forwards to the worker as a
	// PushTaskMessage.
	task  *odysseus.TaskContent
	// errCh receives the result of sending the task on the worker stream
	// (nil on success). manageWorker sends on it without blocking, so a
	// receiver that is not ready at that moment misses the value.
	errCh chan error
}

// Worker holds the per-connection state of one registered worker.
type Worker struct {
	// quit signals that the worker is going away. It is closed by
	// manageWorker when the heartbeat expires and by handleWorkerMsg when
	// Recv fails; handleWorkerMsg also sends the reason of a goodbye message
	// on it.
	quit       chan interface{}
	// taskCh carries tasks to manageWorker, which pushes them to the worker.
	// handleWorkerMsg closes it after the worker said goodbye.
	taskCh     chan *dispatchTask
	// resultCh carries task results from handleWorkerMsg to manageWorker.
	resultCh   chan *omanager.SubmitTaskResult
	// uuid identifies the worker; it is the key used in WorkerManager's maps.
	uuid       int64
	// publicKey is taken from MinerPubkey of the first device-info message.
	publicKey  string
	// deviceInfo is taken from the first device-info message; nil until then.
	deviceInfo []*omanager.DeviceInfo
	// recentTask caches tasks by TaskId. manageWorker adds an entry after a
	// task was pushed successfully and looks it up again when the worker
	// submits the result.
	recentTask *lru.Cache
	// stream is the worker's RegisterWorker stream, used to Send manager
	// messages and Recv worker messages.
	stream     omanager.NodeManagerService_RegisterWorkerServer
}

// WorkerManager keeps track of the connected workers and their heartbeats.
type WorkerManager struct {
	// rdb is the redis client handed to NewWorkerManager.
	rdb       *redis.Client
	// heartBeat maps a worker uuid to the unix time (seconds) at which its
	// last heartbeat response was recorded.
	heartBeat map[int64]int64
	// hbRwLock guards heartBeat.
	hbRwLock  sync.RWMutex

	// workers maps a worker uuid to its Worker.
	workers  map[int64]*Worker
	// wkRwLock guards workers.
	wkRwLock sync.RWMutex
	// quit is closed by Stop; the per-worker loops watch it to shut down.
	quit     chan struct{}

	// node is used by manageWorker to sign task results.
	node *Node
}

// NewWorkerManager builds a WorkerManager that uses rdb and node, with empty
// worker and heartbeat tables and an open quit channel.
func NewWorkerManager(rdb *redis.Client, node *Node) *WorkerManager {
	wm := new(WorkerManager)
	wm.rdb = rdb
	wm.node = node
	wm.quit = make(chan struct{})
	wm.workers = make(map[int64]*Worker)
	wm.heartBeat = make(map[int64]int64)
	return wm
}

// Stop signals shutdown to the per-worker loops by closing wm.quit.
//
// Calling Stop again after it has returned is a no-op; closing an already
// closed channel would panic. Stop is not synchronized, so it must not be
// called from several goroutines at the same time.
func (wm *WorkerManager) Stop() {
	select {
	case <-wm.quit:
		// Already closed by an earlier Stop.
	default:
		close(wm.quit)
	}
}

// UpdateHeartBeat records the current unix time (seconds) as the latest
// heartbeat of the worker identified by uuid.
func (wm *WorkerManager) UpdateHeartBeat(uuid int64) {
	wm.hbRwLock.Lock()
	defer wm.hbRwLock.Unlock()

	lastSeen := time.Now().Unix()
	wm.heartBeat[uuid] = lastSeen
}

// UpdateStatus is a placeholder: its body is empty, so calling it has no
// effect.
func (wm *WorkerManager) UpdateStatus(worker *Worker) {

}

// GetHeartBeat returns the unix time (seconds) of the last heartbeat recorded
// for uuid, or 0 when none has been recorded.
func (wm *WorkerManager) GetHeartBeat(uuid int64) int64 {
	wm.hbRwLock.RLock()
	lastSeen := wm.heartBeat[uuid]
	wm.hbRwLock.RUnlock()

	return lastSeen
}

// GetWorker looks up the worker registered under uuid. It returns nil when no
// such worker is registered.
func (wm *WorkerManager) GetWorker(uuid int64) *Worker {
	wm.wkRwLock.RLock()
	found := wm.workers[uuid]
	wm.wkRwLock.RUnlock()

	return found
}

// AddNewWorker registers the stream of a newly connected worker under uuid and
// starts the goroutine that receives the worker's messages.
//
// It returns ErrWorkerExist when uuid is already registered, or the error from
// creating the recent-task cache. On error no worker is registered and no
// goroutine is started.
func (wm *WorkerManager) AddNewWorker(uuid int64, worker omanager.NodeManagerService_RegisterWorkerServer) (*Worker, error) {
	wm.wkRwLock.Lock()
	defer wm.wkRwLock.Unlock()

	if _, exist := wm.workers[uuid]; exist {
		return nil, ErrWorkerExist
	}

	// Build everything that can fail before the worker becomes visible.
	taskCache, err := lru.New(100)
	if err != nil {
		return nil, err
	}

	w := &Worker{
		// quit must be a real channel: closing a nil channel panics and a
		// send on a nil channel blocks forever. The capacity of one lets the
		// receive loop hand over a goodbye reason without waiting for a
		// receiver.
		quit:       make(chan interface{}, 1),
		taskCh:     make(chan *dispatchTask),
		resultCh:   make(chan *omanager.SubmitTaskResult),
		uuid:       uuid,
		recentTask: taskCache,
		stream:     worker,
	}
	wm.workers[uuid] = w
	go wm.handleWorkerMsg(w)

	return w, nil
}

// Callback is run by manageWorker after it has tried to send a message to the
// worker. err is the value returned by the stream's Send (nil on success).
// The returned bool reports whether manageWorker should keep running; false
// makes it return.
type Callback func(err error) bool

// manageWorker is the send loop of one worker. It multiplexes everything the
// manager sends on the worker stream: periodic heartbeat, status, device-info
// and device-usage requests, pushed tasks, signed proofs for submitted
// results, and a goodbye message on shutdown.
//
// It returns ErrHeartBeatExpired when the worker's last heartbeat is older
// than the check interval, and nil when worker.taskCh is closed or after the
// goodbye message has been sent because the manager is stopping. A failed
// Send is logged and passed to the message's callback; it does not end the
// loop by itself.
//
// NOTE(review): handleWorkerMsg also closes worker.quit when Recv fails. If
// that happened before the heartbeat expires here, the close below panics.
func (wm *WorkerManager) manageWorker(worker *Worker) error {
	log.WithField("worker", worker.uuid).Info("start manage worker")
	defer log.WithField("worker", worker.uuid).Info("exit manage worker")

	heartBeatDuration := time.Second * 10
	workerCheckDuration := heartBeatDuration * 3

	heartBeatTicker := time.NewTicker(heartBeatDuration)
	defer heartBeatTicker.Stop()

	workerCheckTicker := time.NewTicker(workerCheckDuration)
	defer workerCheckTicker.Stop()

	// The three tickers below fire quickly the first time and are slowed
	// down by their callbacks once a request has been sent successfully.
	statusTicker := time.NewTicker(time.Second * 10)
	defer statusTicker.Stop()

	deviceInfoTicker := time.NewTicker(time.Second * 10)
	defer deviceInfoTicker.Stop()

	deviceUsageTicker := time.NewTicker(time.Second * 10)
	defer deviceUsageTicker.Stop()

	for {
		// msg.Message stays nil when the selected case has nothing to send.
		var msg = new(omanager.ManagerMessage)
		var callback = Callback(func(err error) bool {
			// do nothing
			return true
		})

		select {
		case <-wm.quit:
			gb := new(omanager.ManagerMessage_GoodbyeMessage)
			gb.GoodbyeMessage = &omanager.GoodbyeMessage{}
			msg.Message = gb
			// wm.quit is closed and therefore always ready; leave the loop
			// after the goodbye instead of sending it over and over.
			callback = func(err error) bool {
				return false
			}

		case <-workerCheckTicker.C:
			if time.Now().Unix()-wm.GetHeartBeat(worker.uuid) > int64(workerCheckDuration.Seconds()) {
				wm.InActiveWorker(worker)
				// remove worker
				close(worker.quit)
				return ErrHeartBeatExpired
			}

		case <-heartBeatTicker.C:
			hb := new(omanager.ManagerMessage_HeartbeatRequest)
			hb.HeartbeatRequest = &omanager.HeartbeatRequest{
				Timestamp: uint64(time.Now().Unix()),
			}
			msg.Message = hb

		case <-deviceInfoTicker.C:
			deviceInfo := new(omanager.ManagerMessage_DeviceRequest)
			deviceInfo.DeviceRequest = &omanager.DeviceInfoRequest{}
			msg.Message = deviceInfo
			callback = func(err error) bool {
				if err == nil {
					deviceInfoTicker.Reset(time.Second * 180)
				}
				return true
			}

		case <-deviceUsageTicker.C:
			deviceUsage := new(omanager.ManagerMessage_DeviceUsage)
			deviceUsage.DeviceUsage = &omanager.DeviceUsageRequest{}
			msg.Message = deviceUsage
			callback = func(err error) bool {
				if err == nil {
					deviceUsageTicker.Reset(time.Second * 180)
				}
				return true
			}

		case <-statusTicker.C:
			status := new(omanager.ManagerMessage_StatusRequest)
			status.StatusRequest = &omanager.StatusRequest{}
			msg.Message = status
			callback = func(err error) bool {
				if err == nil {
					statusTicker.Reset(time.Second * 120)
				}
				return true
			}

		case dtask, ok := <-worker.taskCh:
			if !ok {
				// taskCh was closed: the worker said goodbye.
				return nil
			}
			task := dtask.task
			taskMsg := new(omanager.ManagerMessage_PushTaskMessage)
			taskMsg.PushTaskMessage = &omanager.PushTaskMessage{
				TaskId:    task.TaskId,
				TaskType:  task.TaskType,
				Workload:  0,
				TaskCmd:   task.TaskCmd,
				TaskParam: task.TaskParam,
			}
			msg.Message = taskMsg
			callback = func(err error) bool {
				if err == nil {
					// add task to cache.
					worker.recentTask.Add(task.TaskId, task)
				}

				// Report the outcome without blocking on the dispatcher.
				select {
				case dtask.errCh <- err:
				default:
					// err ch is invalid
				}
				return true
			}

		case result := <-worker.resultCh:
			// verify result and make a new signature.
			data, exist := worker.recentTask.Get(result.TaskId)
			if !exist {
				log.WithField("worker", worker.uuid).Error("task not found for verify result")
				continue
			}
			task := data.(*odysseus.TaskContent)
			if result.TaskId != task.TaskId {
				log.WithField("worker", worker.uuid).Error("task id not match")
				continue
			}
			// todo: verify container_signature and miner_signature
			//manager_signature = sign(hash((task_id+hash(task_param)+hash(task_result)+container_signature+miner_signature+workload))
			paramHash := sha3.Sum256(task.TaskParam)
			resultHash := sha3.Sum256(result.TaskResult)
			dataHash := sha3.Sum256(utils.CombineBytes([]byte(result.TaskId), paramHash[:], resultHash[:],
				result.ContainerSignature, result.MinerSignature, big.NewInt(0).Bytes()))
			//result.ContainerSignature, result.MinerSignature, big.NewInt(int64(task.Workload)).Bytes()))

			signature, err := wm.node.Sign(dataHash[:])
			if err != nil {
				log.WithError(err).Error("sign result failed")
				continue
			}

			proof := new(omanager.ManagerMessage_ProofTaskResult)
			proof.ProofTaskResult = &omanager.ProofTaskResult{
				TaskId:           result.TaskId,
				ManagerSignature: signature,
			}
			// Attach the proof so that it is actually sent to the worker.
			msg.Message = proof
			callback = func(err error) bool {
				if err == nil {
					// remove task from cache.
					worker.recentTask.Remove(result.TaskId)
				}
				// todo: post event for task succeed or failed
				return true
			}
		}

		// Nothing to send for this event (e.g. a heartbeat check that
		// passed); do not put an empty message on the stream.
		if msg.Message == nil {
			continue
		}

		err := worker.stream.Send(msg)
		if err != nil {
			log.WithError(err).Error("send message to worker failed")
		}
		if !callback(err) {
			return nil
		}
	}
}
// handleWorkerMsg is the receive loop of one worker. It reads messages from
// the worker stream until the manager stops, the worker quits, Recv fails, or
// the worker says goodbye, and dispatches each message by type:
//
//   - goodbye: hand the reason to worker.quit, close worker.taskCh and return
//   - submit task result: forward the result to manageWorker via resultCh
//   - heartbeat response: record the heartbeat
//   - device info: on the first one, store public key and devices and call
//     wm.AddWorker
//   - status / device usage: only logged for now
//
// The quit channels are checked between messages only; a blocked Recv is not
// interrupted by them.
//
// NOTE(review): manageWorker also closes worker.quit when the heartbeat
// expires. If that happened before Recv fails here, the close below panics.
// NOTE(review): worker.taskCh is closed here although this goroutine is not
// the one sending on it; confirm no dispatcher can send after a goodbye.
func (wm *WorkerManager) handleWorkerMsg(worker *Worker) {
	log.WithField("worker", worker.uuid).Info("start handle worker message")
	defer log.WithField("worker", worker.uuid).Info("exit handle worker message")
	for {
		// Non-blocking check of the stop signals before the next Recv.
		select {
		case <-wm.quit:
			return
		case <-worker.quit:
			return
		default:
		}

		wmsg, err := worker.stream.Recv()
		if err != nil {
			log.WithError(err).WithField("worker", worker.uuid).Error("recv msg failed")
			close(worker.quit)
			return
		}

		switch msg := wmsg.Message.(type) {
		case *omanager.WorkerMessage_GoodbyeMessage:
			// Hand over the reason, but do not hang forever when nobody is
			// receiving and the manager is shutting down.
			select {
			case worker.quit <- msg.GoodbyeMessage.Reason:
			case <-wm.quit:
			}
			close(worker.taskCh)
			return

		case *omanager.WorkerMessage_SubmitTaskResult:
			// manageWorker is the only reader of resultCh; once it has
			// exited a plain send would block this goroutine forever.
			select {
			case worker.resultCh <- msg.SubmitTaskResult:
			case <-wm.quit:
				return
			case <-worker.quit:
				return
			}

		case *omanager.WorkerMessage_HeartbeatResponse:
			wm.UpdateHeartBeat(worker.uuid)
			log.WithFields(log.Fields{
				"worker":   worker.uuid,
				"hearBeat": time.Now().Unix() - int64(msg.HeartbeatResponse.Timestamp),
			}).Debug("receive worker heartbeat")

		case *omanager.WorkerMessage_Status:
			// todo: store worker status
			log.WithFields(log.Fields{
				"worker": worker.uuid,
			}).Debugf("receive worker status:0x%x", msg.Status.DeviceStatus)

		case *omanager.WorkerMessage_DeviceInfo:
			// todo: handler worker device info
			log.WithFields(log.Fields{
				"worker": worker.uuid,
			}).Debugf("receive worker device info:%v", msg.DeviceInfo.Devices)
			if worker.deviceInfo == nil {
				// first time receive device info
				worker.publicKey = msg.DeviceInfo.MinerPubkey
				worker.deviceInfo = msg.DeviceInfo.Devices
				wm.AddWorker(worker)
			}

		case *omanager.WorkerMessage_DeviceUsage:
			// todo: handler worker device usage
			log.WithFields(log.Fields{
				"worker": worker.uuid,
			}).Debugf("receive worker device usage:%v", msg.DeviceUsage.Usage)

		default:
			log.WithField("worker", worker.uuid).Error(fmt.Sprintf("unsupport msg type %T", msg))
		}
	}
}
