Commit c3342c58 authored by Wade

add rate limit

parent b28247b8
...@@ -80,6 +80,7 @@ require ( ...@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect google.golang.org/grpc v1.72.0 // indirect
......
...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= ...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
......
...@@ -4,11 +4,13 @@ import ( ...@@ -4,11 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net/http"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/firebase/genkit/go/plugins/server"
) )
func main() { func main() {
...@@ -16,7 +18,7 @@ func main() { ...@@ -16,7 +18,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ds := deepseek.DeepSeek{ ds := deepseek.DeepSeek{
APIKey:"sk-9f70df871a7c4b8aa566a3c7a0603706", APIKey: "sk-9f70df871a7c4b8aa566a3c7a0603706",
} }
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds)) g, err := genkit.Init(ctx, genkit.WithPlugins(&ds))
...@@ -24,7 +26,7 @@ func main() { ...@@ -24,7 +26,7 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
m :=ds.DefineModel(g, m := ds.DefineModel(g,
deepseek.ModelDefinition{ deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support Type: "chat", // Must be chat for tool support
...@@ -32,30 +34,40 @@ func main() { ...@@ -32,30 +34,40 @@ func main() {
nil) nil)
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
//genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`)) ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil{ if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
return return "", err
} }
fmt.Println("resp.Text()",resp.Text()) fmt.Println("resp.Text()", resp.Text())
if err != nil {
return "", err
}
// if err != nil { text := resp.Text()
// return "", err return text, nil
// } })
// text := resp.Text()
// return text, nil
// })
//<-ctx.Done()
}
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl := NewRateLimiter(10, 20, 5)
// 创建 Genkit HTTP 处理器
mux := http.NewServeMux()
for _, a := range genkit.ListFlows(g) {
handler := rl.Middleware(genkit.Handler(a))
mux.Handle("POST /"+a.Name(), handler)
}
// 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
const provider = "deepseek" const provider = "deepseek"
var ( var (
mediaSupportedModels = []string{deepseek.DeepSeekChat,deepseek.DeepSeekCoder,deepseek.DeepSeekReasoner} mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{ // toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral", // "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2", // "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
...@@ -34,7 +34,6 @@ var ( ...@@ -34,7 +34,6 @@ var (
} }
) )
// DeepSeek holds configuration for the plugin. // DeepSeek holds configuration for the plugin.
type DeepSeek struct { type DeepSeek struct {
APIKey string // DeepSeek API key APIKey string // DeepSeek API key
...@@ -44,14 +43,11 @@ type DeepSeek struct { ...@@ -44,14 +43,11 @@ type DeepSeek struct {
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the provider name. // Name returns the provider name.
func (d DeepSeek) Name() string { func (d DeepSeek) Name() string {
return provider return provider
} }
// ModelDefinition represents a model with its name and type. // ModelDefinition represents a model with its name and type.
type ModelDefinition struct { type ModelDefinition struct {
Name string Name string
...@@ -112,7 +108,6 @@ type generator struct { ...@@ -112,7 +108,6 @@ type generator struct {
apiKey string apiKey string
} }
// generate implements the Genkit model generation interface. // generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
...@@ -164,7 +159,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun ...@@ -164,7 +159,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
Kind: ai.PartKind(chunk.Index), Kind: ai.PartKind(chunk.Index),
} }
finalResponse.Message.Content = append(finalResponse.Message.Content,&p) finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
} }
return finalResponse, nil // Return the final merged response return finalResponse, nil // Return the final merged response
} }
...@@ -181,7 +176,6 @@ func concatMessageParts(parts []*ai.Part) string { ...@@ -181,7 +176,6 @@ func concatMessageParts(parts []*ai.Part) string {
return sb.String() return sb.String()
} }
/* /*
// Choice represents a completion choice generated by the model. // Choice represents a completion choice generated by the model.
...@@ -205,5 +199,3 @@ type Part struct { ...@@ -205,5 +199,3 @@ type Part struct {
} }
*/ */
...@@ -424,5 +424,3 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai ...@@ -424,5 +424,3 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Documents: docs, Documents: docs,
}, nil }, nil
} }
...@@ -185,4 +185,3 @@ func TestMilvusIntegration(t *testing.T) { ...@@ -185,4 +185,3 @@ func TestMilvusIntegration(t *testing.T) {
assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}") assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}")
}) })
} }
package main
import (
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter bundles the state used to throttle HTTP requests: a rate
// limiter from golang.org/x/time/rate plus a buffered channel that serves as
// a pool of concurrency slots.
type RateLimiter struct {
	// limiter is the rate limiter that Allow waits on before admitting a
	// request.
	limiter *rate.Limiter

	// queue holds one element per request that currently owns a concurrency
	// slot. Allow sends into it and Release receives from it.
	queue chan struct{}

	// maxWorkers is the value NewRateLimiter used as the capacity of queue.
	maxWorkers int

	// mu is a mutex owned by this limiter. Because of it, a RateLimiter must
	// be shared by pointer and never copied.
	mu sync.Mutex
}
// NewRateLimiter builds a RateLimiter.
//
// ratePerSecond and burst are passed to rate.NewLimiter as the limit and the
// burst size. maxWorkers becomes the capacity of the concurrency-slot channel
// and is also recorded on the returned value.
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiter = rate.NewLimiter(rate.Limit(ratePerSecond), burst)
	rl.queue = make(chan struct{}, maxWorkers)
	rl.maxWorkers = maxWorkers
	return rl
}
// Allow reports whether a request may proceed.
//
// It first blocks in the rate limiter's Wait until the limiter admits the
// request or ctx ends; any error from Wait yields false. It then tries,
// without blocking, to claim a concurrency slot by sending into rl.queue and
// returns false when every slot is already taken.
//
// Every call that returns true must be paired with exactly one call to
// Release, which gives the slot back.
func (rl *RateLimiter) Allow(ctx context.Context) bool {
	// Deliberately no locking of rl.mu here. rate.Limiter is documented as
	// safe for concurrent use and channel sends are synchronized by the
	// runtime, so a mutex adds no protection. Holding one across the blocking
	// Wait would queue all callers on a lock that cannot observe ctx, so a
	// cancelled or timed-out request would stay stuck until the lock freed.
	if err := rl.limiter.Wait(ctx); err != nil {
		return false
	}

	select {
	case rl.queue <- struct{}{}:
		return true
	default:
		// All concurrency slots are in use; reject rather than wait.
		return false
	}
}
// Release gives back one concurrency slot by receiving a single element from
// rl.queue. It is meant to be called once for every Allow that returned true;
// if no slot is currently held, the receive blocks until one is.
func (rl *RateLimiter) Release() {
	<-rl.queue
}
// Middleware wraps next in an http.Handler that applies this limiter to every
// request.
//
// A request for which Allow returns false is answered with
// 429 Too Many Requests and never reaches next. Otherwise next serves the
// request and the concurrency slot is released when it returns.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if allowed := rl.Allow(r.Context()); !allowed {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}
		// Deferred so the slot is returned even if next panics.
		defer rl.Release()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(limited)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment