package controllers

import (
	"ai_developer_admin/libs/jose"
	"ai_developer_admin/libs/kong"
	"ai_developer_admin/libs/mysql"
	"ai_developer_admin/libs/odysseus"
	"ai_developer_admin/libs/postgres"
	"ai_developer_admin/libs/redis"
	"ai_developer_admin/libs/utils"
	"ai_developer_admin/models"
	"crypto/md5"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/beego/beego/orm"
	"github.com/beego/beego/v2/core/logs"
	beego "github.com/beego/beego/v2/server/web"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"
)

// UserController handles the user endpoints: login, logout, registration,
// profile lookup and free-call quotas. Helpers its methods call, such as
// respond and Check, are not defined in this file — presumably they come
// from the embedded MainController.
type UserController struct {
	MainController
}

//func (u *UserController) respond(code int, message string, data ...interface{}) {
//	u.Ctx.Output.SetStatus(code)
//	var d interface{}
//	if len(data) > 0 {
//		d = data[0]
//	}
//	u.Data["json"] = struct {
//		Code    int         `json:"code"`
//		Message string      `json:"message"`
//		Data    interface{} `json:"data,omitempty"`
//	}{
//		Code:    code,
//		Message: message,
//		Data:    d,
//	}
//	u.ServeJSON()
//}

//func (u *UserController) Test() {
//	info, err := u.Check()
//	if err != nil {
//		u.respond(http.StatusUnauthorized, err.Error())
//		return
//	}
//	u.respond(http.StatusOK, "", info)
//}

// Login signs a user in with a Web3Auth id token.
//
// The id token is verified against the Web3Auth JWKS, must not be expired,
// and its audience must equal the configured web3ClientId. The account is
// looked up by mail and registered on first login; a missing username is
// derived from the mail hash and a missing profile image is taken from the
// token claims. A token already cached in Redis under "token:user-<id>" is
// reused; otherwise a new one is generated and cached for
// utils.DEFAULT_EXPIRE_SECONDS. The token is returned in the Authorization
// response header.
func (server *UserController) Login() {
	loginRequest := models.LoginRequest{}
	decodeErr := json.Unmarshal(server.Ctx.Input.RequestBody, &loginRequest)
	logs.Debug("loginRequest", loginRequest)
	if decodeErr != nil {
		server.respond(models.NoRequestBody, decodeErr.Error())
		return
	}

	claims, err := jose.Verify("https://api-auth.web3auth.io/jwks", loginRequest.Web3AuthPublicKey, loginRequest.IdToken)
	if err != nil {
		server.respond(models.LoginFailed, "Web3Auth verify failed")
		return
	}
	if now := time.Now().Unix(); now >= claims.Exp {
		server.respond(models.LoginFailed, "Web3Auth login expire")
		return
	}
	if clientId, _ := beego.AppConfig.String("web3ClientId"); strings.Compare(clientId, claims.Aud) != 0 {
		server.respond(models.LoginFailed, "web3auth clientId error")
		return
	}

	// The verified token is authoritative for the verifier id.
	user := loginRequest.RequstUser
	user.VerifierId = claims.VerifierId
	if len(user.Username) == 0 {
		user.Username = "ai_" + generatorMD5(user.Mail)[:8]
	}
	if len(user.ProfileImage) == 0 {
		user.ProfileImage = claims.ProfileImage
	}

	account := &models.User{Mail: user.Mail}
	lookupErr := mysql.GetMysqlInstace().Ormer.Read(account, "mail")
	if pingErr := mysql.Ping(); pingErr != nil {
		server.respond(models.CreateUserFailed, pingErr.Error())
		return
	}
	if lookupErr != nil {
		// The read by mail failed with the database reachable: treat it as a
		// first login and register the account.
		registered, regErr := regisgerUser(user)
		if registered == nil {
			server.respond(models.CreateUserFailed, regErr.Error())
			return
		}
		account = registered
	}

	key := "token:user-" + strconv.Itoa(account.Id)
	if cached, _ := redis.GetDataToString(key); cached != "" {
		server.Ctx.Output.Header("Authorization", cached)
		server.respond(http.StatusOK, "")
		return
	}

	duration := utils.DEFAULT_EXPIRE_SECONDS
	tokenString, err := utils.GenerateToken(account, account.Id, duration)
	if err != nil {
		server.respond(models.LoginFailed, "failed")
		return
	}
	redis.SetKeyAndData(key, tokenString, time.Duration(duration)*time.Second)
	server.Ctx.Output.Header("Authorization", tokenString)
	server.respond(http.StatusOK, "")
}

// Logout removes the caller's cached token, stored in Redis under
// "token:user-<id>" by Login.
//
// NOTE(review): a successful logout is answered with http.StatusUnauthorized
// rather than http.StatusOK. It is kept as-is because clients may rely on it
// to trigger their login redirect — confirm this is intended.
func (server *UserController) Logout() {
	token, err := server.Check()
	if err != nil {
		server.respond(http.StatusUnauthorized, err.Error())
		return
	}
	redis.DeleteKey("token:user-" + strconv.Itoa(token.UserID))
	server.respond(http.StatusUnauthorized, "")
}

// Regisger registers a user directly from a JSON models.User request body.
// A missing username defaults to "ai_" followed by the first 8 hex digits of
// the MD5 of the mail address. The verifier id is always set to the username.
func (server *UserController) Regisger() {
	body := server.Ctx.Input.RequestBody
	var user models.User
	decodeErr := json.Unmarshal(body, &user)
	logs.Debug("user", user)
	if decodeErr != nil {
		server.respond(models.NoRequestBody, decodeErr.Error())
		return
	}

	if user.Username == "" {
		user.Username = "ai_" + generatorMD5(user.Mail)[:8]
	}
	user.VerifierId = user.Username

	if _, regErr := regisgerUser(user); regErr != nil {
		server.respond(models.CreateUserFailed, regErr.Error())
		return
	}
	server.respond(http.StatusOK, "")
}

// UserInfo returns the authenticated caller's profile together with their
// balance and credit quota.
//
// The balance comes from odysseus when that lookup succeeds; otherwise the
// stored User.Balance is used. A negative odysseus balance is reported as 0
// and is deducted from the level's credit quota instead. Both amounts are
// divided by 1,000,000 for the response (presumably micro-units — confirm).
func (server *UserController) UserInfo() {
	token, err := server.Check()
	if err != nil {
		server.respond(http.StatusUnauthorized, err.Error())
		return
	}
	checkUser := &models.User{Id: token.UserID}
	if err = mysql.GetMysqlInstace().Ormer.Read(checkUser); err != nil {
		server.respond(models.BusinessFailed, err.Error())
		return
	}

	userBalance := checkUser.Balance
	// overdraft holds a negative odysseus balance. It is only set when the
	// lookup actually succeeded, so a failed call cannot reduce the quota.
	var overdraft int64
	if balance, balanceErr := odysseus.GetUserBalance(int64(checkUser.Id)); balanceErr == nil {
		userBalance = balance
		if balance < 0 {
			userBalance = 0
			overdraft = balance
		}
	}

	creditQuota := int64(0)
	checkUserLevel := &models.UserLevel{Level: checkUser.Level}
	if err = mysql.GetMysqlInstace().Ormer.Read(checkUserLevel, "level"); err == nil {
		creditQuota = checkUserLevel.CreditQuota + overdraft
	}

	// Convert to float64 before dividing: integer division would silently
	// drop the fractional part of the amounts.
	const unitScale = 1000000
	userInfo := models.UserInfo{
		Id:           checkUser.Id,
		Name:         checkUser.Name,
		Username:     checkUser.Username,
		Mail:         checkUser.Mail,
		Phone:        checkUser.Phone,
		CustomId:     checkUser.CustomId,
		ChainAccount: checkUser.ChainAccount,
		Type:         checkUser.Type,
		IsAuthed:     checkUser.IsAuthed,
		Balance:      float64(userBalance) / unitScale,
		Level:        checkUser.Level,
		ProfileImage: checkUser.ProfileImage,
		Role:         checkUser.Role,
		CreditQuota:  float64(creditQuota) / unitScale,
	}
	server.respond(http.StatusOK, "", userInfo)
}

// FreeCallCount returns one page of task types together with the caller's
// free-call allowances and recent usage.
//
// The optional JSON body (models.AppRequest) supplies page (default 1), size
// (default 10) and keyword (a LIKE substring filter on the task type name).
// The response contains:
//   - the total number of matching task types;
//   - the caller's day/month usage as reported by odysseus.UserFreeUesd;
//   - the level's day/month free-call quotas;
//   - per task type, the remaining day/month free calls and the number of the
//     caller's tasks of that type over the last 7 and 30 days.
func (server *UserController) FreeCallCount() {
	token, err := server.Check()
	if err != nil {
		server.respond(http.StatusUnauthorized, err.Error())
		return
	}
	body := server.Ctx.Input.RequestBody
	appRequest := models.AppRequest{}
	// The body is optional: a missing or malformed body keeps the defaults below.
	_ = json.Unmarshal(body, &appRequest)
	logs.Debug("appRequest", appRequest, string(body))

	if appRequest.Page == 0 {
		appRequest.Page = 1
	}
	if appRequest.Size == 0 {
		appRequest.Size = 10
	}
	offset := (appRequest.Page - 1) * appRequest.Size

	checkUser := &models.User{Id: token.UserID}
	if err = mysql.GetMysqlInstace().Ormer.Read(checkUser); err != nil {
		server.respond(models.BusinessFailed, err.Error())
		return
	}

	checkUserLevel := &models.UserLevel{Level: checkUser.Level}
	if err = mysql.GetMysqlInstace().Ormer.Read(checkUserLevel, "level"); err != nil {
		server.respond(models.BusinessFailed, err.Error())
		return
	}

	countQB, _ := orm.NewQueryBuilder("mysql")
	queryQB, _ := orm.NewQueryBuilder("mysql")

	countQB.Select("count(*) AS total").
		From("task_type").Where("deleted = 0")

	cond := fmt.Sprintf("user_level_task_type.task_type_id = task_type.id and user_level_task_type.user_level = %d", checkUser.Level)
	queryQB.Select("task_type.id",
		"task_type.name",
		"task_type.api_path",
		"task_type.type",
		"task_type.category",
		"user_level_task_type.free_call_count_day",
		"user_level_task_type.free_call_count_month",
		"user_level_task_type.free_call_count_year",
		"user_level_task_type.free_call_count_total").
		From("task_type").LeftJoin("user_level_task_type").On(cond)
	queryQB.Where("task_type.deleted = 0")

	// The keyword is user input: bind it as a query parameter instead of
	// formatting it into the SQL text. The column is qualified because the
	// page query joins a second table.
	var args []interface{}
	if appRequest.Keyword != "" {
		countQB.And("task_type.name like ?")
		queryQB.And("task_type.name like ?")
		args = append(args, "%"+appRequest.Keyword+"%")
	}
	queryQB.Limit(int(appRequest.Size)).Offset(int(offset))

	var total int64
	if err := mysql.GetMysqlInstace().Ormer.Raw(countQB.String(), args...).QueryRow(&total); err != nil {
		logs.Error("FreeCallCount: count task types:", err)
	}

	type TempTaskType struct {
		Id                 int64  `json:"id"`
		Name               string `json:"name"`
		Type               int    `json:"type"`
		TypeDesc           string `json:"type_desc"`
		ApiPath            string `json:"api_path"`
		Category           int    `json:"category"`
		FreeCallCountDay   int64  `json:"free_call_count_day"`
		FreeCallCountMonth int64  `json:"free_call_count_month"`
		WeekCount          int    `json:"week_count"`
		MonthCount         int    `json:"month_count"`
	}
	var taskTypes []*TempTaskType
	if _, err := mysql.GetMysqlInstace().Ormer.Raw(queryQB.String(), args...).QueryRows(&taskTypes); err != nil {
		logs.Error("FreeCallCount: query task types:", err)
	}

	ids := make([]int64, 0, len(taskTypes))
	quotedIDs := make([]string, 0, len(taskTypes))
	for _, value := range taskTypes {
		ids = append(ids, value.Id)
		quotedIDs = append(quotedIDs, "'"+strconv.FormatInt(value.Id, 10)+"'")
		value.TypeDesc = models.ModelType(value.Type).String()
	}
	uids := []int64{int64(token.UserID)}

	// Turn the level's free-call allowances into what remains after usage.
	totalDayUsed := int64(0)
	totalMonthUsed := int64(0)
	if used, err := odysseus.UserFreeUesd(uids, ids); err == nil {
		userdata := used[int64(token.UserID)]
		totalDayUsed = userdata.TotalDayUsed
		totalMonthUsed = userdata.TotalMonthUsed
		for _, value := range taskTypes {
			taskUsed := userdata.TasksUsed[value.Id]
			value.FreeCallCountDay -= taskUsed.TaskDayUsed
			value.FreeCallCountMonth -= taskUsed.TaskMonthUsed
		}
	}

	// With no task types on this page there is nothing to count. Building
	// the list anyway would fail: an empty "type in()" is invalid SQL.
	if len(quotedIDs) > 0 {
		typeList := strings.Join(quotedIDs, ",")
		now := time.Now()
		// NOTE(review): the window boundaries take the local calendar date but
		// label it UTC; confirm that tasks.time is stored in UTC.
		end := time.Date(now.Year(), now.Month(), now.Day(), 23, 59, 59, 0, time.UTC)

		if weekCount, err := freeCallTaskCounts(token.UserID, typeList, freeCallWindowStart(now, 7), end); err == nil {
			for _, value := range taskTypes {
				value.WeekCount = findTaskCount(weekCount, int(value.Id))
			}
		}
		if monthCount, err := freeCallTaskCounts(token.UserID, typeList, freeCallWindowStart(now, 30), end); err == nil {
			for _, value := range taskTypes {
				value.MonthCount = findTaskCount(monthCount, int(value.Id))
			}
		}
	}

	responseData := struct {
		Total              int64       `json:"total"`
		TotalDayUsed       int64       `json:"total_day_used"`
		TotalMonthUsed     int64       `json:"total_month_used"`
		FreeCallCountDay   int         `json:"free_call_count_day"`
		FreeCallCountMonth int         `json:"free_call_count_month"`
		Data               interface{} `json:"data,omitempty"`
	}{
		Total:              total,
		TotalDayUsed:       totalDayUsed,
		TotalMonthUsed:     totalMonthUsed,
		FreeCallCountDay:   checkUserLevel.FreeCallCountDay,
		FreeCallCountMonth: checkUserLevel.FreeCallCountMonth,
		Data:               taskTypes,
	}
	server.respond(http.StatusOK, "", responseData)
}

// freeCallWindowStart returns 00:00:00 (labelled UTC, like the window end in
// FreeCallCount) of the calendar day that lies `days` days before now.
func freeCallWindowStart(now time.Time, days int) time.Time {
	t := now.Add(-time.Duration(days) * 24 * time.Hour)
	return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}

// freeCallTaskCounts counts userID's tasks per type whose time lies in
// [start, end]. It only considers types in typeList, a comma-separated list
// of quoted task type ids. The query is grouped by type so that each type
// gets its own count row.
func freeCallTaskCounts(userID int, typeList string, start, end time.Time) ([]models.TaskCount, error) {
	qb, err := orm.NewQueryBuilder("mysql")
	if err != nil {
		return nil, err
	}
	qb.Select("count(*)", "type").
		From("tasks").
		Where(fmt.Sprintf("time >= '%s'", start.Format(format))).
		And(fmt.Sprintf("time <= '%s'", end.Format(format))).
		And(fmt.Sprintf("uid = '%d'", userID)).
		And(fmt.Sprintf("type in(%s)", typeList)).
		GroupBy("type")
	return postgres.CountTasks(qb.String())
}

// findTaskCount returns the count reported for task type id in counts.
// Both Type and Count are decimal strings. A missing type yields 0, and a
// parse error on Count is ignored (strconv.Atoi's result is returned as-is).
func findTaskCount(counts []models.TaskCount, id int) int {
	want := strconv.Itoa(id)
	for i := range counts {
		if counts[i].Type != want {
			continue
		}
		n, _ := strconv.Atoi(counts[i].Count)
		return n
	}
	return 0
}

// regisgerUser creates a new user account and provisions it in Kong.
//
// It fails with "user is exist" when a user with the same mail already
// exists. Otherwise it:
//   - inserts the user with Level 0 and Role 4;
//   - reads the row back by the id Insert returned and sets CustomId to that id;
//   - registers the user with kong.CreateUser, deleting the new row again if
//     that fails;
//   - attaches a level-based rate limit via kong.SetRateLimit and persists
//     the changes;
//   - creates an API key and a JWT credential (best effort).
//
// The stored user is returned.
func regisgerUser(user models.User) (*models.User, error) {
	ormer := mysql.GetMysqlInstace().Ormer

	count, err := ormer.QueryTable("user").Filter("mail", user.Mail).Count()
	if err != nil {
		// Without a reliable count we cannot rule out a duplicate account.
		return nil, fmt.Errorf("check existing user: %w", err)
	}
	if count > 0 {
		return nil, errors.New("user is exist")
	}

	user.CreatedTime = time.Now()
	user.UpdatedTime = user.CreatedTime
	user.Level = 0
	user.Role = 4

	id, err := ormer.Insert(&user)
	if err != nil {
		return nil, errors.New("create user failed")
	}

	// Read back by primary key: nothing here checks that the username is
	// unique, so reading by username could pick up a different row.
	checkUser := &models.User{Id: int(id)}
	if err = ormer.Read(checkUser); err != nil {
		return nil, errors.New("create user failed")
	}
	checkUser.CustomId = strconv.Itoa(checkUser.Id)

	// rollback removes the freshly inserted row when Kong provisioning fails.
	rollback := func() {
		if _, delErr := ormer.Delete(checkUser); delErr != nil {
			logs.Error("regisgerUser: rollback of user", checkUser.Id, "failed:", delErr)
		}
	}

	data, err := kong.CreateUser(checkUser)
	if err != nil {
		rollback()
		return nil, err
	}
	if data.Id == "" {
		rollback()
		return nil, errors.New(data.Message)
	}

	checkUserLevel := &models.UserLevel{Level: checkUser.Level}
	if err = ormer.Read(checkUserLevel, "level"); err != nil {
		// Level lookup failed: continue without a rate-limit plugin.
		logs.Debug("Recharge 用户等级查找失败")
	} else if plugin, pluginErr := kong.SetRateLimit(checkUser, checkUserLevel, ""); pluginErr == nil {
		checkUser.RateLimitPluginId = plugin.Id
	}

	// Persist CustomId and RateLimitPluginId. The account already exists, so
	// a failure here is logged rather than returned.
	if _, err = ormer.Update(checkUser); err != nil {
		logs.Error("regisgerUser: update user", checkUser.Id, "failed:", err)
	}

	createApiKey(checkUser)
	createJWTToken(checkUser)

	return checkUser, nil
}

// createApiKey creates a Kong API key for checkUser via kong.CreateApiKey and
// records it as a models.ApiKey named "test". It is best effort: every
// failure is logged and the function returns without a key.
func createApiKey(checkUser *models.User) {
	data, err := kong.CreateApiKey(checkUser)
	if err != nil {
		logs.Error("createApiKey: kong create api key for user", checkUser.Id, "failed:", err)
		return
	}
	if data.Id == "" {
		logs.Error("createApiKey: kong returned an empty api key id for user", checkUser.Id)
		return
	}

	timestamp := time.Now()
	app := models.ApiKey{
		Name:        "test",
		ApiKey:      data.Key,
		UserId:      checkUser.Id,
		CreatedTime: timestamp,
		UpdatedTime: timestamp,
		Deleted:     0,
		ApiKeyId:    data.Id,
	}
	if _, err := mysql.GetMysqlInstace().Ormer.Insert(&app); err != nil {
		logs.Error("createApiKey: store api key for user", checkUser.Id, "failed:", err)
	}
}

// createJWTToken creates a Kong JWT credential for checkUser via
// kong.CreateJwt, signs a token for it with utils.GenerateKongToken, and
// records both as a models.JwtToken named "test". It is best effort:
//   - Kong and database failures are logged and the function returns;
//   - a signing failure is logged and an empty token is stored.
func createJWTToken(checkUser *models.User) {
	data, store, err := kong.CreateJwt(checkUser)
	if err != nil {
		logs.Error("createJWTToken: kong create jwt for user", checkUser.Id, "failed:", err)
		return
	}
	if data.Id == "" {
		logs.Error("createJWTToken: kong returned an empty jwt id for user", checkUser.Id)
		return
	}

	jwtToken, err := utils.GenerateKongToken(data, "")
	if err != nil {
		logs.Error("createJWTToken: sign token for user", checkUser.Id, "failed:", err)
		jwtToken = ""
	}
	timestamp := time.Now()
	app := models.JwtToken{
		Name:          "test",
		JwtCredential: string(store),
		JwtToken:      jwtToken,
		UserId:        checkUser.Id,
		CreatedTime:   timestamp,
		UpdatedTime:   timestamp,
		Deleted:       0,
		JwtId:         data.Id,
	}

	if _, err := mysql.GetMysqlInstace().Ormer.Insert(&app); err != nil {
		logs.Error("createJWTToken: store jwt for user", checkUser.Id, "failed:", err)
	}
}

// generatorMD5 returns the lowercase hexadecimal MD5 digest of code.
func generatorMD5(code string) string {
	hasher := md5.New()
	// hash.Hash writes never fail, so the result is safe to discard.
	_, _ = io.WriteString(hasher, code)
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
