package main

import (
	"context"
	"net/http"
	"sync"

	"golang.org/x/time/rate"
)

// RateLimiter pairs a token-bucket rate limiter with a cap on the number of
// requests that may be admitted concurrently.
type RateLimiter struct {
	// limiter throttles the average request rate (token bucket from
	// golang.org/x/time/rate).
	limiter    *rate.Limiter
	// queue is a buffered channel used as a counting semaphore: Allow sends
	// into it to take a slot and Release receives from it to give one back,
	// so its capacity bounds the number of in-flight requests.
	queue      chan struct{}
	// maxWorkers records the capacity queue was created with.
	maxWorkers int
	// mu is a mutex available to RateLimiter's methods. Note that
	// rate.Limiter and channel operations are already safe for concurrent
	// use on their own.
	mu         sync.Mutex
}

// NewRateLimiter returns a RateLimiter whose token bucket refills at
// ratePerSecond tokens per second and holds at most burst tokens, and which
// admits at most maxWorkers requests at the same time.
//
// maxWorkers should be positive: with 0 the semaphore channel is unbuffered,
// so Allow's non-blocking send can never succeed and every request is
// rejected; a negative value makes the channel allocation panic.
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
	bucket := rate.NewLimiter(rate.Limit(ratePerSecond), burst)
	slots := make(chan struct{}, maxWorkers)

	rl := &RateLimiter{maxWorkers: maxWorkers}
	rl.limiter = bucket
	rl.queue = slots
	return rl
}

// Allow reports whether a request may proceed.
//
// It first blocks until the token-bucket limiter grants a token, then tries
// to take a concurrency slot without blocking. It returns false if the
// limiter's Wait fails (for example ctx is cancelled, or its deadline would
// pass before a token becomes available) or if all slots are in use.
//
// When Allow returns true the caller must call Release exactly once to give
// the slot back.
func (rl *RateLimiter) Allow(ctx context.Context) bool {
	// Deliberately no rl.mu here: rate.Limiter is safe for concurrent use and
	// channel sends are synchronized. Holding a mutex across the blocking Wait
	// would serialize every request, and goroutines parked on Lock() could not
	// observe their own ctx cancellation or deadline.
	if err := rl.limiter.Wait(ctx); err != nil {
		return false
	}
	// Non-blocking acquire: when every slot is taken, reject immediately
	// instead of queueing. The rate token consumed above is not refunded.
	select {
	case rl.queue <- struct{}{}:
		return true
	default:
		return false
	}
}

// Release gives back the concurrency slot taken by a successful call to
// Allow. It must be called exactly once for each Allow that returned true.
//
// Release panics if no slot is currently held, i.e. it was called more times
// than Allow succeeded.
func (rl *RateLimiter) Release() {
	select {
	case <-rl.queue:
	default:
		// With correct Allow/Release pairing the semaphore always holds at
		// least the caller's own slot here, so an empty channel means a
		// programming error. A plain receive would hang this goroutine
		// forever with no diagnostic; fail loudly instead, as
		// golang.org/x/sync/semaphore does on over-release.
		panic("ratelimiter: Release called without a matching successful Allow")
	}
}

// Middleware wraps next so that every request must pass rl.Allow, using the
// request's context, before it is served. Rejected requests get a
// 429 Too Many Requests response and never reach next. Admitted requests hold
// a concurrency slot that is released when next.ServeHTTP returns (the
// release is deferred, so it also runs if next panics).
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if admitted := rl.Allow(r.Context()); !admitted {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}
		defer rl.Release()

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(limited)
}
