package distribute

import (
	"encoding/json"
	"math/big"
	"sort"
	"strconv"
	"sync"
	"time"
)

// ModelUsedLevel buckets a model's share of the total usage count into
// coarse tiers. The constants are ordered from least used to most used.
type ModelUsedLevel int

const (
	ModelUsedLevelSuperLow  ModelUsedLevel = iota //  < 2%
	ModelUsedLevelVeryLow                         // 2% ~ 5%
	ModelUsedLevelLow                             // 5% ~ 10%
	ModelUsedLevelMiddle                          // 10% ~ 30%
	ModelUsedLevelHigh                            // 30% ~ 50%
	ModelUsedLevelVeryHigh                        // 50% ~ 80%
	ModelUsedLevelSuperHigh                       //  >= 80%
)

// getModelLevel classifies count/total into a ModelUsedLevel. Each tier's
// lower bound is inclusive and its upper bound exclusive.
//
// A non-positive total carries no usage information, so the lowest tier is
// returned. Without this guard the division yields NaN or +Inf, every "<"
// comparison is false, and the result would wrongly be
// ModelUsedLevelSuperHigh.
func getModelLevel(count int, total int) ModelUsedLevel {
	if total <= 0 {
		return ModelUsedLevelSuperLow
	}

	switch usedRate := float64(count) / float64(total); {
	case usedRate < 0.02:
		return ModelUsedLevelSuperLow
	case usedRate < 0.05:
		return ModelUsedLevelVeryLow
	case usedRate < 0.1:
		return ModelUsedLevelLow
	case usedRate < 0.3:
		return ModelUsedLevelMiddle
	case usedRate < 0.5:
		return ModelUsedLevelHigh
	case usedRate < 0.8:
		return ModelUsedLevelVeryHigh
	default:
		return ModelUsedLevelSuperHigh
	}
}

// HardwareRequireInfo holds a model's hardware requirements as decoded from
// JSON. Every quantity is carried as a string; the Int* accessors convert
// them to integers on demand.
//
// NOTE(review): Gpus has no json tag, so encoding/json writes it under the
// key "Gpus" (decoding matches keys case-insensitively). Confirm this is the
// intended wire name before adding a tag.
type HardwareRequireInfo struct {
	DiskSize string `json:"disk_size"`
	Gpus     []struct {
		Gpu string `json:"gpu"`
	}
	MemorySize string `json:"memory_size"`
}

// IntDiskSize returns DiskSize parsed as a base-10 int64. A value that is
// empty or not a valid integer yields 0.
func (h HardwareRequireInfo) IntDiskSize() int64 {
	// The error is dropped on purpose: the accessor has no error return and
	// callers treat an unparsable size as "no requirement" (0).
	size, _ := strconv.ParseInt(h.DiskSize, 10, 64)
	return size
}

// IntMemorySize returns MemorySize parsed as a base-10 int64. A value that
// is empty or not a valid integer yields 0.
func (h HardwareRequireInfo) IntMemorySize() int64 {
	// Best-effort conversion; see IntDiskSize for why the error is ignored.
	size, _ := strconv.ParseInt(h.MemorySize, 10, 64)
	return size
}

// IntGpu returns the Gpu value of the idx-th entry of Gpus parsed as an int.
// It returns 0 when idx is outside [0, len(Gpus)) or when the value is not a
// valid integer.
func (h HardwareRequireInfo) IntGpu(idx int) int {
	// Reject negative indexes as well as too-large ones; a negative idx
	// would otherwise panic with an index-out-of-range error.
	if idx < 0 || idx >= len(h.Gpus) {
		return 0
	}
	// Best-effort conversion; an unparsable value is reported as 0.
	gpu, _ := strconv.Atoi(h.Gpus[idx].Gpu)
	return gpu
}

// ModelDetailInfo is the JSON-decodable record describing a single model.
// Within this file, Count is the usage weight used for ordering and for the
// usage-level computation, TaskID is the key models are indexed by, and
// ImageName is the name FindModelByName matches against.
type ModelDetailInfo struct {
	Time            time.Time           `json:"time"`
	Count           int                 `json:"count"`
	HardwareRequire HardwareRequireInfo `json:"hardware_require"`
	ImageName       string              `json:"image_name"`
	SignURL         string              `json:"sign_url"`
	TaskID          int                 `json:"task_id"`
	Kind            int                 `json:"kind"`
	FileExpiresTime string              `json:"file_expires_time"`
	AccessStatus    int                 `json:"access_status"`
	PublishStatus   int                 `json:"publish_status"`
	// NOTE(review): the JSON key is spelled "estimat_exe_time"; presumably it
	// mirrors the upstream payload — confirm before "fixing" the spelling.
	EstimateExeTime int `json:"estimat_exe_time"`
	StartUpTime     int `json:"start_up_time"`
	RunningMem      int `json:"running_mem"`
	// Cmd is kept as raw JSON; it is not decoded anywhere in this file.
	Cmd json.RawMessage `json:"cmd"`
}

// SortedModelDetailInfos is a slice of ModelDetailInfo that implements
// sort.Interface. Elements are ordered by ascending Count; models with equal
// Count are ordered by ascending TaskID.
type SortedModelDetailInfos []ModelDetailInfo

// Len reports how many models the slice holds.
func (s SortedModelDetailInfos) Len() int { return len(s) }

// Less reports whether the model at index i sorts before the one at index j:
// the smaller Count wins, and TaskID breaks ties.
func (s SortedModelDetailInfos) Less(i, j int) bool {
	left, right := &s[i], &s[j]
	if left.Count != right.Count {
		return left.Count < right.Count
	}
	return left.TaskID < right.TaskID
}

// Swap exchanges the models at indexes i and j.
func (s SortedModelDetailInfos) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}

// HeapModelInfos is an in-memory model library. It keeps the models sorted
// (see SortedModelDetailInfos), indexed by TaskID, together with the sum of
// all usage counts. The accessor methods take mux before touching the other
// fields, so the struct must not be copied once it is in use.
type HeapModelInfos struct {
	// mux serialises access to models, modelMap and totalHot.
	mux sync.Mutex
	// models is the sorted model list. The field is unexported, so
	// encoding/json never sees it; the json tag it used to carry had no
	// effect and was reported by go vet, hence it was dropped.
	models SortedModelDetailInfos
	// modelMap indexes the same models by TaskID.
	modelMap map[int]ModelDetailInfo
	// totalHot is the sum of Count over all models.
	totalHot *big.Int
}

// Compile-time assertion that *HeapModelInfos satisfies ModelLibrary.
var _ ModelLibrary = (*HeapModelInfos)(nil)

// NewHeapModelInfos builds a model library from models.
//
// The slice is sorted in place (ascending Count, then ascending TaskID) and
// retained by the returned value rather than copied. The TaskID index and
// the total usage count are derived from it.
func NewHeapModelInfos(models []ModelDetailInfo) *HeapModelInfos {
	library := new(HeapModelInfos)
	// setModels performs the same sort, index and sum steps used when the
	// library is refreshed, so construction and update cannot drift apart.
	library.setModels(models)
	return library
}

// UpdateModelInfo replaces the library's contents with models. It delegates
// to setModels, which sorts the given slice in place and rebuilds the TaskID
// index and the total usage count while holding the lock.
func (h *HeapModelInfos) UpdateModelInfo(models []ModelDetailInfo) {
	h.setModels(models)
}

// setModels swaps in a new model set while holding the lock.
//
// models is sorted in place (ascending Count, then ascending TaskID) and the
// same backing array is retained, so the caller observes the reordering and
// shares storage with the library. When several entries share a TaskID the
// index keeps the one that sorts last, while every entry still contributes
// to the total.
func (h *HeapModelInfos) setModels(models []ModelDetailInfo) {
	h.mux.Lock()
	defer h.mux.Unlock()

	sorted := SortedModelDetailInfos(models)
	sort.Sort(sorted)

	// Build the derived state in locals first, then publish it in one go.
	byTaskID := make(map[int]ModelDetailInfo)
	hotSum := big.NewInt(0)
	for idx := range sorted {
		byTaskID[sorted[idx].TaskID] = sorted[idx]
		hotSum.Add(hotSum, big.NewInt(int64(sorted[idx].Count)))
	}

	h.models = sorted
	h.totalHot = hotSum
	h.modelMap = byTaskID
}

// GetSortedModels returns the library's current sorted model slice. The
// slice is the internal one, not a copy, so callers must treat it as
// read-only; use AllModel when an independent copy is needed.
func (h *HeapModelInfos) GetSortedModels() []ModelDetailInfo {
	h.mux.Lock()
	current := h.models
	h.mux.Unlock()
	return current
}

// GetModelUsedLevel reports the usage tier of the model whose TaskID is
// modelID, based on its Count relative to the sum of all counts.
//
// ModelUsedLevelSuperLow is returned when the model is unknown or when no
// usage has been recorded at all (total of zero).
func (h *HeapModelInfos) GetModelUsedLevel(modelID int) ModelUsedLevel {
	h.mux.Lock()
	defer h.mux.Unlock()

	model, ok := h.modelMap[modelID]
	if !ok {
		return ModelUsedLevelSuperLow
	}
	// A zero total would make the ratio 0/0 (NaN), which compares false
	// against every threshold and would be reported as the highest tier.
	if h.totalHot.Sign() <= 0 {
		return ModelUsedLevelSuperLow
	}
	return getModelLevel(model.Count, int(h.totalHot.Int64()))
}

// FindModel looks up the model whose TaskID equals i. When no such model is
// indexed, the zero ModelDetailInfo is returned.
func (h *HeapModelInfos) FindModel(i int) ModelDetailInfo {
	h.mux.Lock()
	defer h.mux.Unlock()
	// Indexing a map with an absent key yields the zero value, which is
	// exactly the "not found" result.
	return h.modelMap[i]
}

// FindModelByName scans the sorted models and returns the first one whose
// ImageName equals s. When none matches, the zero ModelDetailInfo is
// returned.
func (h *HeapModelInfos) FindModelByName(s string) ModelDetailInfo {
	h.mux.Lock()
	defer h.mux.Unlock()

	var found ModelDetailInfo
	for idx := range h.models {
		if h.models[idx].ImageName != s {
			continue
		}
		found = h.models[idx]
		break
	}
	return found
}

// InstalledWorkerCount is a placeholder: it ignores its argument and always
// returns 0 until the lookup below is implemented.
func (h *HeapModelInfos) InstalledWorkerCount(i int) int {
	// todo: query the count from mongo.
	return 0
}

// AllModel returns a copy of the sorted model list, so callers may modify
// the result without affecting the library.
//
// The receiver is a pointer: with a value receiver the struct, including
// its sync.Mutex, is copied on every call, so the lock taken here would be a
// private copy and would not exclude concurrent setModels calls.
func (h *HeapModelInfos) AllModel() SortedModelDetailInfos {
	h.mux.Lock()
	defer h.mux.Unlock()
	m := make(SortedModelDetailInfos, len(h.models))
	copy(m, h.models)
	return m
}
