package cachedata

import (
	"fmt"
	"github.com/odysseus/payment/model"
	goredislib "github.com/redis/go-redis/v9"
	log "github.com/sirupsen/logrus"
	"strconv"
	"time"
)

// QueryResult is the outcome of one Query call, as delivered on the channel
// returned by MQuery. On success Task is the resolved task and Error is nil;
// on failure Query returns a nil Task together with the error.
type QueryResult struct {
	Task  *model.TaskType
	Error error
}

// QueryParam is one unit of work sent on c.taskChan by MQuery and consumed by
// queryRoutine, which hands its fields to doQuery.
type QueryParam struct {
	path   string           // task path, resolved via GetTaskWithPath
	uid    int64            // user whose free-call quota / balance is checked
	result chan QueryResult // channel doQuery delivers the outcome on (non-blocking send)
}

// queryRoutine is a worker loop: it takes QueryParam values off c.taskChan and
// runs each one through doQuery, returning once c.ctx is cancelled.
func (c *CacheData) queryRoutine() {
	for {
		select {
		case job := <-c.taskChan:
			c.doQuery(job.path, job.uid, job.result)
		case <-c.ctx.Done():
			return
		}
	}
}

// MQuery enqueues one query per (paths[i], uids[i]) pair on the worker queue
// (c.taskChan) and returns the channel the results will be delivered on.
//
// The returned channel is buffered for exactly len(paths) results and is never
// closed, so callers should receive len(paths) values. A QueryResult carries no
// index back to its input pair; arrival order depends on how the worker(s)
// consume c.taskChan.
//
// It returns ErrInvalidParam when paths and uids differ in length, and
// c.ctx.Err() if the context is cancelled before every task could be enqueued.
func (c *CacheData) MQuery(paths []string, uids []int64) (chan QueryResult, error) {
	if len(paths) != len(uids) {
		return nil, ErrInvalidParam
	}
	// One slot per task so doQuery's non-blocking send can never find it full.
	result := make(chan QueryResult, len(paths))

	for i := range paths {
		// The result channel must travel with the task: without it doQuery
		// sends on a nil channel, hits its default branch and drops the result.
		param := QueryParam{path: paths[i], uid: uids[i], result: result}
		select {
		case c.taskChan <- param:
		case <-c.ctx.Done():
			// queryRoutine exits on cancellation, so a plain send here could
			// block forever once the queue fills up.
			return nil, c.ctx.Err()
		}
	}
	return result, nil
}

// doQuery runs Query for a single (path, uid) pair and delivers the outcome on
// result without blocking. If result cannot take a value immediately (full
// buffer, no ready receiver, or a nil channel), the outcome is dropped and an
// error is logged instead.
func (c *CacheData) doQuery(path string, uid int64, result chan QueryResult) {
	var out QueryResult
	out.Task, out.Error = c.Query(path, uid)

	select {
	case result <- out:
		return
	default:
	}
	log.Errorf("failed to send query result to channel")
}

// Query authorizes and bills a single request by user uid for the task at path.
//
// While holding the per-user lock for the whole call, it loads the user, the
// task, the user's level and (optionally) the level/task-type config. It then
// tries two things in order:
//   - consume one free call, when the free-call counters are below their limits;
//   - otherwise, if charge + task.Price fits within balance + credit quota, add
//     task.Price to the user's charge counter.
//
// It returns the task on success. On failure it returns ErrWaitLockTimeout if
// the lock could not be acquired, ErrUserDeleted for deleted users,
// ErrBalanceNotEnough when neither path is possible, or the underlying
// lookup/Redis error.
func (c *CacheData) Query(path string, uid int64) (*model.TaskType, error) {
	locked, release, _ := c.getUserLockWithRetry(uid, USER_INFO_LOCK_DURATION*10)
	if !locked {
		return nil, ErrWaitLockTimeout
	}
	defer release()

	// 1. get user info.
	user, err := c.GetUserInfo(uid)
	if err != nil {
		log.WithError(err).Error("failed to get user info")
		return nil, err
	}
	if user.Deleted == 1 {
		// err is nil on this path; log the uid rather than a nil error.
		log.WithField("uid", uid).Error("user is deleted")
		return nil, ErrUserDeleted
	}
	// 2. get task info.
	task, err := c.GetTaskWithPath(path)
	if err != nil {
		log.WithError(err).Error("failed to get task info")
		return nil, err
	}
	// 3. get user level info.
	userLevel, err := c.GetUserLevelInfoByLevelId(int64(user.Level))
	if err != nil {
		log.WithError(err).Error("failed to get user level info")
		return nil, err
	}

	// The level/task-type config only feeds the per-task free quota. A lookup
	// failure is logged and the request falls through to the paid path rather
	// than dereferencing a missing config further down.
	userLevelAndTaskType, err := c.GetUserLevelAndTaskTypeByLevelIdAndTaskTypeId(int64(user.Level), int64(task.ID))
	if err != nil {
		log.WithError(err).Error("failed to get user level and task type info")
		userLevelAndTaskType = nil
	}

	// 4a. try to consume a free call. Without a level/task-type config the
	// per-task free quota is 0, so the check could never pass anyway.
	if userLevelAndTaskType != nil {
		passed, err := c.checkQueryForFreeTimes(uid, userLevel, userLevelAndTaskType)
		if err != nil {
			log.WithError(err).Error("failed to check free times")
			return nil, err
		}
		if passed {
			// Best-effort accounting: the request is already authorized.
			if err := c.costFreeTime(uid, user, userLevel, task, userLevelAndTaskType); err != nil {
				log.WithError(err).Error("failed to record free call usage")
			}
			return task, nil
		}
	}

	// 4b. fall back to charging task.Price against balance + credit.
	passed, err := c.checkQueryForCost(uid, user, userLevel, task)
	if err != nil {
		// An insufficient balance is an expected outcome, not an internal failure.
		if err != ErrBalanceNotEnough {
			log.WithError(err).Error("failed to check cost")
		}
		return nil, err
	}
	if !passed {
		return nil, ErrBalanceNotEnough
	}
	if err := c.costCharge(uid, task.Price); err != nil {
		log.WithError(err).Error("failed to record charge")
	}
	return task, nil
}

// checkQueryForFreeTimes reports whether user uid still has a free call left for
// the task type in taskAndUserLevel.
//
// It reads four Redis counters, named "k-u-<uid>:<period>:" and
// "k-t-<taskType>-u-<uid>:<period>:" for today and for this month. Each counter
// must be strictly below its limit:
//   - user/day against userLevel.FreeCallCountDay;
//   - user/month against userLevel.FreeCallCountMonth;
//   - task/day and task/month against taskAndUserLevel.FreeCallCountDay.
//
// A missing counter counts as 0 uses. A nil config means there is no free quota,
// so the result is (false, nil). Redis errors and unparsable counters are
// returned as errors.
func (c *CacheData) checkQueryForFreeTimes(uid int64, userLevel *model.UserLevel, taskAndUserLevel *model.UserLevelTaskType) (bool, error) {
	// Without both configs there is no free quota to consume. Returning here
	// also avoids dereferencing a nil config while building the task keys.
	if userLevel == nil || taskAndUserLevel == nil {
		return false, nil
	}

	layoutDay := "2006-01-02"
	layoutMonth := "2006-01"
	var (
		userDayFreeMax   = int(userLevel.FreeCallCountDay)
		userMonthFreeMax = int(userLevel.FreeCallCountMonth)
		// NOTE(review): the per-task monthly counter is also compared against
		// this daily limit; confirm whether a separate monthly quota was intended.
		taskFreeMax = int(taskAndUserLevel.FreeCallCountDay)
	)

	// A single timestamp keeps all keys on the same day/month at a boundary.
	now := time.Now()
	day, month := now.Format(layoutDay), now.Format(layoutMonth)
	// Order: user/day, user/month, task+user/day, task+user/month.
	keys := []string{
		fmt.Sprintf("k-u-%d:%s:", uid, day),
		fmt.Sprintf("k-u-%d:%s:", uid, month),
		fmt.Sprintf("k-t-%d-u-%d:%s:", taskAndUserLevel.TaskTypeId, uid, day),
		fmt.Sprintf("k-t-%d-u-%d:%s:", taskAndUserLevel.TaskTypeId, uid, month),
	}
	limits := []int{userDayFreeMax, userMonthFreeMax, taskFreeMax, taskFreeMax}

	pip := c.rdb.Pipeline()
	cmds := make([]*goredislib.StringCmd, len(keys))
	for i, key := range keys {
		cmds[i] = pip.Get(c.ctx, key)
	}
	// Pipeline.Exec returns the first failed command's error, and GET on a
	// missing key fails with goredislib.Nil. A missing counter only means "no
	// free calls used yet in this period", so Nil must not abort the check.
	// Real errors are still caught per command below.
	if _, err := pip.Exec(c.ctx); err != nil && err != goredislib.Nil {
		return false, err
	}

	for i, cmd := range cmds {
		used := 0
		switch err := cmd.Err(); {
		case err == nil:
			n, perr := strconv.Atoi(cmd.Val())
			if perr != nil {
				return false, fmt.Errorf("parse free-call counter %q: %w", keys[i], perr)
			}
			used = n
		case err != goredislib.Nil:
			return false, err
		}
		if used >= limits[i] {
			return false, nil
		}
	}
	return true, nil
}

// checkQueryForCost reports whether user uid can afford task.Price.
//
// It reads the charge counter ("charge-<uid>:") and the balance
// ("bal-<uid>:") from Redis; a missing key counts as 0. It returns (true, nil)
// when charge + task.Price <= balance + userLevel.CreditQuota, and
// (false, ErrBalanceNotEnough) otherwise. Redis errors and unparsable values
// are returned as errors. The user parameter is currently unused.
func (c *CacheData) checkQueryForCost(uid int64, user *UserInfo, userLevel *model.UserLevel, task *model.TaskType) (bool, error) {
	chargeKey := fmt.Sprintf("charge-%d:", uid)
	balKey := fmt.Sprintf("bal-%d:", uid)

	pip := c.rdb.Pipeline()
	chargeCmd := pip.Get(c.ctx, chargeKey)
	balCmd := pip.Get(c.ctx, balKey)
	// Pipeline.Exec surfaces goredislib.Nil for a GET on a missing key. A user
	// without a charge or balance entry yet is valid, so only abort on real
	// errors; each command is checked individually below.
	if _, err := pip.Exec(c.ctx); err != nil && err != goredislib.Nil {
		return false, err
	}

	// readInt64 parses a counter value, treating a missing key as 0.
	readInt64 := func(cmd *goredislib.StringCmd, key string) (int64, error) {
		switch err := cmd.Err(); {
		case err == goredislib.Nil:
			return 0, nil
		case err != nil:
			return 0, err
		}
		n, err := strconv.ParseInt(cmd.Val(), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("parse %q: %w", key, err)
		}
		return n, nil
	}

	charge, err := readInt64(chargeCmd, chargeKey)
	if err != nil {
		return false, err
	}
	bal, err := readInt64(balCmd, balKey)
	if err != nil {
		return false, err
	}

	if charge+task.Price <= bal+userLevel.CreditQuota {
		return true, nil
	}
	return false, ErrBalanceNotEnough
}

// costFreeTime records one consumed free call for user uid. It increments the
// same four counters that checkQueryForFreeTimes reads: user/day, user/month,
// task+user/day and task+user/month.
//
// When an INCR creates a counter (new value 1), a TTL is attached:
//   - day counters expire 24h after first use;
//   - month counters live until one day after the end of the calendar month
//     they are named after.
//
// It returns ErrInvalidParam for a nil taskAndUserLevel, or the INCR pipeline
// error. A failure to set TTLs is only logged. The user, userLevel and task
// parameters are currently unused.
func (c *CacheData) costFreeTime(uid int64, user *UserInfo, userLevel *model.UserLevel, task *model.TaskType, taskAndUserLevel *model.UserLevelTaskType) error {
	if taskAndUserLevel == nil {
		return ErrInvalidParam
	}

	layoutDay := "2006-01-02"
	layoutMonth := "2006-01"
	// A single timestamp keeps all keys on the same day/month at a boundary.
	now := time.Now()
	day, month := now.Format(layoutDay), now.Format(layoutMonth)
	var (
		userDayKey       = fmt.Sprintf("k-u-%d:%s:", uid, day)
		userMonthKey     = fmt.Sprintf("k-u-%d:%s:", uid, month)
		taskUserDayKey   = fmt.Sprintf("k-t-%d-u-%d:%s:", taskAndUserLevel.TaskTypeId, uid, day)
		taskUserMonthKey = fmt.Sprintf("k-t-%d-u-%d:%s:", taskAndUserLevel.TaskTypeId, uid, month)
	)

	pip := c.rdb.Pipeline()
	userDayCmd := pip.Incr(c.ctx, userDayKey)
	userMonthCmd := pip.Incr(c.ctx, userMonthKey)
	taskUserDayCmd := pip.Incr(c.ctx, taskUserDayKey)
	taskUserMonthCmd := pip.Incr(c.ctx, taskUserMonthKey)
	if _, err := pip.Exec(c.ctx); err != nil {
		return err
	}

	// A fixed 30-day TTL on a monthly counter expires a counter started on the
	// 1st of a 31-day month one day early, which resets the monthly free quota.
	// Anchor the expiry to the start of next month instead (time.Date normalizes
	// month 13 to January of the next year). The extra day is slack for clock
	// skew between this host and Redis.
	monthExpiry := time.Date(now.Year(), now.Month()+1, 1, 0, 0, 0, 0, now.Location()).Add(24 * time.Hour)

	// Only freshly created counters (value 1) need a TTL.
	expip := c.rdb.Pipeline()
	pending := false
	if userDayCmd.Val() == 1 {
		expip.Expire(c.ctx, userDayKey, time.Hour*24)
		pending = true
	}
	if userMonthCmd.Val() == 1 {
		expip.ExpireAt(c.ctx, userMonthKey, monthExpiry)
		pending = true
	}
	if taskUserDayCmd.Val() == 1 {
		expip.Expire(c.ctx, taskUserDayKey, time.Hour*24)
		pending = true
	}
	if taskUserMonthCmd.Val() == 1 {
		expip.ExpireAt(c.ctx, taskUserMonthKey, monthExpiry)
		pending = true
	}
	if !pending {
		return nil
	}
	if _, err := expip.Exec(c.ctx); err != nil {
		log.WithError(err).Error("failed to set expire")
	}
	return nil
}

// costCharge adds fee to the user's charge counter ("charge-<uid>:") with a
// single INCRBY. Only the Redis error is reported; the new counter value is
// discarded.
func (c *CacheData) costCharge(uid int64, fee int64) error {
	// todo: just incr charge.
	return c.rdb.IncrBy(c.ctx, fmt.Sprintf("charge-%d:", uid), fee).Err()
}

// costForFee applies fee to the user's Redis counters inside a TxPipeline
// (MULTI/EXEC):
//   - fee > 0: "charge-<uid>:" and "bal-<uid>:" are both decremented by fee;
//   - fee < 0: only "charge-<uid>:" is decremented by fee (i.e. it grows by |fee|);
//   - fee == 0: nothing is sent to Redis and nil is returned.
//
// NOTE(review): RollbackForFee calls this with -fee, so rolling back a positive
// fee increases the charge counter. If a rollback is meant to release the
// amount costCharge added, the sign looks inverted. Confirm against callers
// before changing.
func (c *CacheData) costForFee(uid int64, fee int64) error {
	// todo: decr charge and balance.
	if fee == 0 {
		return nil
	}
	tx := c.rdb.TxPipeline()
	tx.DecrBy(c.ctx, fmt.Sprintf("charge-%d:", uid), fee)
	if fee > 0 {
		tx.DecrBy(c.ctx, fmt.Sprintf("bal-%d:", uid), fee)
	}
	_, err := tx.Exec(c.ctx)
	return err
}

// PayforFee settles fee for user uid while holding the per-user lock. It
// delegates to costForFee: for fee > 0 both the charge counter and the balance
// are decremented by fee. It returns ErrWaitLockTimeout if the lock cannot be
// acquired.
func (c *CacheData) PayforFee(uid int64, fee int64) error {
	ok, unlock, _ := c.getUserLockWithRetry(uid, USER_INFO_LOCK_DURATION*10)
	if ok {
		defer unlock()
		return c.costForFee(uid, fee)
	}
	return ErrWaitLockTimeout
}

// RollbackForFee calls costForFee with -fee while holding the per-user lock.
// For a positive fee this adds fee to the charge counter and leaves the balance
// untouched. It returns ErrWaitLockTimeout if the lock cannot be acquired.
// NOTE(review): confirm this sign is what callers expect of a "rollback".
func (c *CacheData) RollbackForFee(uid int64, fee int64) error {
	ok, unlock, _ := c.getUserLockWithRetry(uid, USER_INFO_LOCK_DURATION*10)
	if ok {
		defer unlock()
		return c.costForFee(uid, -fee)
	}
	return ErrWaitLockTimeout
}
