Commit c3342c58 authored by Wade's avatar Wade

add rate limit

parent b28247b8
...@@ -80,6 +80,7 @@ require ( ...@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect google.golang.org/grpc v1.72.0 // indirect
......
...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= ...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
......
...@@ -4,11 +4,13 @@ import ( ...@@ -4,11 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net/http"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/firebase/genkit/go/plugins/server"
) )
func main() { func main() {
...@@ -16,7 +18,7 @@ func main() { ...@@ -16,7 +18,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ds := deepseek.DeepSeek{ ds := deepseek.DeepSeek{
APIKey:"sk-9f70df871a7c4b8aa566a3c7a0603706", APIKey: "sk-9f70df871a7c4b8aa566a3c7a0603706",
} }
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds)) g, err := genkit.Init(ctx, genkit.WithPlugins(&ds))
...@@ -24,38 +26,48 @@ func main() { ...@@ -24,38 +26,48 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
m :=ds.DefineModel(g, m := ds.DefineModel(g,
deepseek.ModelDefinition{ deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support Type: "chat", // Must be chat for tool support
}, },
nil) nil)
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
//genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`)) ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil{ if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
return return "", err
} }
fmt.Println("resp.Text()",resp.Text())
// if err != nil { fmt.Println("resp.Text()", resp.Text())
// return "", err
// }
// text := resp.Text() if err != nil {
// return text, nil return "", err
// }) }
//<-ctx.Done() text := resp.Text()
} return text, nil
})
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl := NewRateLimiter(10, 20, 5)
// 创建 Genkit HTTP 处理器
mux := http.NewServeMux()
for _, a := range genkit.ListFlows(g) {
handler := rl.Middleware(genkit.Handler(a))
mux.Handle("POST /"+a.Name(), handler)
}
// 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
const provider = "deepseek" const provider = "deepseek"
var ( var (
mediaSupportedModels = []string{deepseek.DeepSeekChat,deepseek.DeepSeekCoder,deepseek.DeepSeekReasoner} mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{ // toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral", // "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2", // "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
...@@ -34,154 +34,148 @@ var ( ...@@ -34,154 +34,148 @@ var (
} }
) )
// DeepSeek holds configuration for the plugin. // DeepSeek holds configuration for the plugin.
type DeepSeek struct { type DeepSeek struct {
APIKey string // DeepSeek API key APIKey string // DeepSeek API key
//ServerAddress string //ServerAddress string
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the provider name. // Name returns the provider name.
func (d DeepSeek) Name() string { func (d DeepSeek) Name() string {
return provider return provider
} }
// ModelDefinition represents a model with its name and type. // ModelDefinition represents a model with its name and type.
type ModelDefinition struct { type ModelDefinition struct {
Name string Name string
Type string Type string
} }
// // DefineModel defines a DeepSeek model in Genkit. // // DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if !d.initted { if !d.initted {
panic("deepseek.Init not called") panic("deepseek.Init not called")
} }
// Define model info, supporting multiturn and system role. // Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{ mi := ai.ModelInfo{
Label: model.Name, Label: model.Name,
Supports: &ai.ModelSupports{ Supports: &ai.ModelSupports{
Multiturn: true, Multiturn: true,
SystemRole: true, SystemRole: true,
Media: false, // DeepSeek API primarily supports text. Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation. Tools: false, // Tools not yet supported in this implementation.
}, },
Versions: []string{}, Versions: []string{},
} }
if info != nil { if info != nil {
mi = *info mi = *info
} }
meta := &ai.ModelInfo{ meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name, // Label: "DeepSeek - " + model.Name,
Label: model.Name, Label: model.Name,
Supports: mi.Supports, Supports: mi.Supports,
Versions: []string{}, Versions: []string{},
} }
gen := &generator{model: model, apiKey: d.APIKey} gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate) return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
} }
// Init initializes the DeepSeek plugin. // Init initializes the DeepSeek plugin.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error { func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.initted { if d.initted {
panic("deepseek.Init already called") panic("deepseek.Init already called")
} }
if d == nil || d.APIKey == "" { if d == nil || d.APIKey == "" {
return fmt.Errorf("deepseek: need APIKey") return fmt.Errorf("deepseek: need APIKey")
} }
d.initted = true d.initted = true
return nil return nil
} }
// generator handles model generation. // generator handles model generation.
type generator struct { type generator struct {
model ModelDefinition model ModelDefinition
apiKey string apiKey string
} }
// generate implements the Genkit model generation interface. // generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// stream := cb != nil // stream := cb != nil
if len(input.Messages) == 0 { if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required") return nil, fmt.Errorf("prompt or messages required")
} }
// Set up the Deepseek client // Set up the Deepseek client
// Initialize DeepSeek client. // Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey) client := deepseek.NewClient(g.apiKey)
// Create a chat completion request // Create a chat completion request
request := &deepseek.ChatCompletionRequest{ request := &deepseek.ChatCompletionRequest{
Model: g.model.Name, Model: g.model.Name,
} }
for _, msg := range input.Messages { for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role] role, ok := roleMapping[msg.Role]
if !ok { if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role) return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
} }
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
}
// Send the request and handle the response // Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request) response, err := client.CreateChatCompletion(ctx, request)
if err != nil { if err != nil {
log.Fatalf("error: %v", err) log.Fatalf("error: %v", err)
} }
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
for _, chunk := range response.Choices { // Print the response
p := ai.Part{ fmt.Println("Response:", response.Choices[0].Message.Content)
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index), // Create a final response with the merged chunks
} finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
finalResponse.Message.Content = append(finalResponse.Message.Content,&p) for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
} }
return finalResponse, nil // Return the final merged response
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
} }
// concatMessageParts concatenates message parts into a single string. // concatMessageParts concatenates message parts into a single string.
func concatMessageParts(parts []*ai.Part) string { func concatMessageParts(parts []*ai.Part) string {
var sb strings.Builder var sb strings.Builder
for _, part := range parts { for _, part := range parts {
if part.IsText() { if part.IsText() {
sb.WriteString(part.Text) sb.WriteString(part.Text)
} }
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them. // Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
} }
return sb.String() return sb.String()
} }
/* /*
// Choice represents a completion choice generated by the model. // Choice represents a completion choice generated by the model.
...@@ -205,5 +199,3 @@ type Part struct { ...@@ -205,5 +199,3 @@ type Part struct {
} }
*/ */
...@@ -71,129 +71,129 @@ import ( ...@@ -71,129 +71,129 @@ import (
// Client 知识库客户端 // Client 知识库客户端
type Client struct { type Client struct {
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670" BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
} }
// SpaceRequest 创建空间的请求结构体 // SpaceRequest 创建空间的请求结构体
type SpaceRequest struct { type SpaceRequest struct {
ID int `json:"id"` ID int `json:"id"`
Name string `json:"name"` Name string `json:"name"`
VectorType string `json:"vector_type"` VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"` DomainType string `json:"domain_type"`
Desc string `json:"desc"` Desc string `json:"desc"`
Owner string `json:"owner"` Owner string `json:"owner"`
SpaceID int `json:"space_id"` SpaceID int `json:"space_id"`
} }
// DocumentRequest 添加文档的请求结构体 // DocumentRequest 添加文档的请求结构体
type DocumentRequest struct { type DocumentRequest struct {
DocName string `json:"doc_name"` DocName string `json:"doc_name"`
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
DocType string `json:"doc_type"` DocType string `json:"doc_type"`
DocToken string `json:"doc_token"` DocToken string `json:"doc_token"`
Content string `json:"content"` Content string `json:"content"`
Source string `json:"source"` Source string `json:"source"`
Labels string `json:"labels"` Labels string `json:"labels"`
Questions []string `json:"questions"` Questions []string `json:"questions"`
} }
// ChunkParameters 分片参数 // ChunkParameters 分片参数
type ChunkParameters struct { type ChunkParameters struct {
ChunkStrategy string `json:"chunk_strategy"` ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"` TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"` SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"` ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"` ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"` Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"` EnableMerge bool `json:"enable_merge"`
} }
// SyncBatchRequest 同步批处理的请求结构体 // SyncBatchRequest 同步批处理的请求结构体
type SyncBatchRequest struct { type SyncBatchRequest struct {
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
SpaceID string `json:"space_id"` SpaceID string `json:"space_id"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"` ChunkParameters ChunkParameters `json:"chunk_parameters"`
} }
// NewClient 创建新的客户端实例 // NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client { func NewClient(ip string, port int) *Client {
return &Client{ return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port), BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
} }
} }
// AddSpace 创建知识空间 // AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) { func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL) url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err) return nil, fmt.Errorf("failed to marshal request: %w", err)
} }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request: %w", err)
} }
httpReq.Header.Set("Accept", "application/json") httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(httpReq) resp, err := client.Do(httpReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err) return nil, fmt.Errorf("failed to send request: %w", err)
} }
return resp, nil return resp, nil
} }
// AddDocument 添加文档 // AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) { func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID) url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err) return nil, fmt.Errorf("failed to marshal request: %w", err)
} }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request: %w", err)
} }
httpReq.Header.Set("Accept", "application/json") httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(httpReq) resp, err := client.Do(httpReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err) return nil, fmt.Errorf("failed to send request: %w", err)
} }
return resp, nil return resp, nil
} }
// SyncBatchDocument 同步批处理文档 // SyncBatchDocument 同步批处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) { func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID) url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err) return nil, fmt.Errorf("failed to marshal request: %w", err)
} }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request: %w", err)
} }
httpReq.Header.Set("Accept", "application/json") httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(httpReq) resp, err := client.Do(httpReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err) return nil, fmt.Errorf("failed to send request: %w", err)
} }
return resp, nil return resp, nil
} }
This diff is collapsed.
This diff is collapsed.
...@@ -12,188 +12,188 @@ import ( ...@@ -12,188 +12,188 @@ import (
) )
var ( var (
connString = flag.String("dbconn", "", "database connection string") connString = flag.String("dbconn", "", "database connection string")
) )
// QA 结构体表示 qa 表的记录 // QA 结构体表示 qa 表的记录
type QA struct { type QA struct {
ID int64 // 主键 ID int64 // 主键
CreatedAt time.Time // 创建时间 CreatedAt time.Time // 创建时间
UserID *int64 // 可空的用户 ID UserID *int64 // 可空的用户 ID
Username *string // 可空的用户名 Username *string // 可空的用户名
Question *string // 可空的问题 Question *string // 可空的问题
Answer *string // 可空的答案 Answer *string // 可空的答案
} }
// QAStore 定义 DAO 接口 // QAStore 定义 DAO 接口
type QAStore interface { type QAStore interface {
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录 // GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error) WriteQA(ctx context.Context, qa QA) (int64, error)
} }
// qaStore 是 QAStore 接口的实现 // qaStore 是 QAStore 接口的实现
type qaStore struct { type qaStore struct {
db *sql.DB db *sql.DB
} }
// NewQAStore 创建新的 QAStore 实例 // NewQAStore 创建新的 QAStore 实例
func NewQAStore(db *sql.DB) QAStore { func NewQAStore(db *sql.DB) QAStore {
return &qaStore{db: db} return &qaStore{db: db}
} }
// GetLatestQA 从 latest_qa 视图读取数据 // GetLatestQA 从 latest_qa 视图读取数据
func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) { func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) {
query := ` query := `
SELECT id, created_at, user_id, username, question, answer SELECT id, created_at, user_id, username, question, answer
FROM latest_qa FROM latest_qa
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)` WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
args := []interface{}{userID} args := []interface{}{userID}
if userID == nil { if userID == nil {
args = []interface{}{nil} args = []interface{}{nil}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err) return nil, fmt.Errorf("query latest_qa: %w", err)
} }
defer rows.Close() defer rows.Close()
var results []QA var results []QA
for rows.Next() { for rows.Next() {
var qa QA var qa QA
var userIDVal sql.NullInt64 var userIDVal sql.NullInt64
var username, question, answer sql.NullString var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil { if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err) return nil, fmt.Errorf("scan row: %w", err)
} }
if userIDVal.Valid { if userIDVal.Valid {
qa.UserID = &userIDVal.Int64 qa.UserID = &userIDVal.Int64
} }
if username.Valid { if username.Valid {
qa.Username = &username.String qa.Username = &username.String
} }
if question.Valid { if question.Valid {
qa.Question = &question.String qa.Question = &question.String
} }
if answer.Valid { if answer.Valid {
qa.Answer = &answer.String qa.Answer = &answer.String
} }
results = append(results, qa) results = append(results, qa)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err) return nil, fmt.Errorf("row iteration: %w", err)
} }
return results, nil return results, nil
} }
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) { func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 { if qa.ID != 0 {
// 更新记录 // 更新记录
query := ` query := `
UPDATE qa UPDATE qa
SET user_id = $1, username = $2, question = $3, answer = $4 SET user_id = $1, username = $2, question = $3, answer = $4
WHERE id = $5 WHERE id = $5
RETURNING id` RETURNING id`
var updatedID int64 var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID) err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID) return 0, fmt.Errorf("no record found with id %d", qa.ID)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("update qa: %w", err) return 0, fmt.Errorf("update qa: %w", err)
} }
return updatedID, nil return updatedID, nil
} }
// 插入新记录 // 插入新记录
query := ` query := `
INSERT INTO qa (user_id, username, question, answer) INSERT INTO qa (user_id, username, question, answer)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
RETURNING id` RETURNING id`
var newID int64 var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID) err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil { if err != nil {
return 0, fmt.Errorf("insert qa: %w", err) return 0, fmt.Errorf("insert qa: %w", err)
} }
return newID, nil return newID, nil
} }
func mainQA() { func mainQA() {
flag.Parse() flag.Parse()
ctx := context.Background() ctx := context.Background()
if *connString == "" { if *connString == "" {
log.Fatal("need -dbconn") log.Fatal("need -dbconn")
} }
db, err := sql.Open("postgres", *connString) db, err := sql.Open("postgres", *connString)
if err != nil { if err != nil {
log.Fatalf("open database: %v", err) log.Fatalf("open database: %v", err)
} }
defer db.Close() defer db.Close()
store := NewQAStore(db) store := NewQAStore(db)
// 示例:读取 user_id=101 的最新 QA // 示例:读取 user_id=101 的最新 QA
results, err := store.GetLatestQA(ctx, int64Ptr(101)) results, err := store.GetLatestQA(ctx, int64Ptr(101))
if err != nil { if err != nil {
log.Fatalf("get latest QA: %v", err) log.Fatalf("get latest QA: %v", err)
} }
for _, qa := range results { for _, qa := range results {
fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n", fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer)) qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
} }
// 示例:插入新 QA // 示例:插入新 QA
newQA := QA{ newQA := QA{
UserID: int64Ptr(101), UserID: int64Ptr(101),
Username: stringPtr("alice"), Username: stringPtr("alice"),
Question: stringPtr("What is AI?"), Question: stringPtr("What is AI?"),
Answer: stringPtr("AI is..."), Answer: stringPtr("AI is..."),
} }
newID, err := store.WriteQA(ctx, newQA) newID, err := store.WriteQA(ctx, newQA)
if err != nil { if err != nil {
log.Fatalf("write QA: %v", err) log.Fatalf("write QA: %v", err)
} }
fmt.Printf("Inserted QA with ID: %d\n", newID) fmt.Printf("Inserted QA with ID: %d\n", newID)
// 示例:更新 QA // 示例:更新 QA
updateQA := QA{ updateQA := QA{
ID: newID, ID: newID,
UserID: int64Ptr(101), UserID: int64Ptr(101),
Username: stringPtr("alice_updated"), Username: stringPtr("alice_updated"),
Question: stringPtr("What is NLP?"), Question: stringPtr("What is NLP?"),
Answer: stringPtr("NLP is..."), Answer: stringPtr("NLP is..."),
} }
updatedID, err := store.WriteQA(ctx, updateQA) updatedID, err := store.WriteQA(ctx, updateQA)
if err != nil { if err != nil {
log.Fatalf("update QA: %v", err) log.Fatalf("update QA: %v", err)
} }
fmt.Printf("Updated QA with ID: %d\n", updatedID) fmt.Printf("Updated QA with ID: %d\n", updatedID)
} }
// 辅助函数:处理指针类型的空值 // 辅助函数:处理指针类型的空值
func int64Ptr(i int64) *int64 { func int64Ptr(i int64) *int64 {
return &i return &i
} }
func stringPtr(s string) *string { func stringPtr(s string) *string {
return &s return &s
} }
func derefInt64(p *int64) interface{} { func derefInt64(p *int64) interface{} {
if p == nil { if p == nil {
return nil return nil
} }
return *p return *p
} }
func derefString(p *string) interface{} { func derefString(p *string) interface{} {
if p == nil { if p == nil {
return nil return nil
} }
return *p return *p
} }
\ No newline at end of file
package main
import (
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter 定义限速器和并发队列
type RateLimiter struct {
limiter *rate.Limiter
queue chan struct{}
maxWorkers int
mu sync.Mutex
}
// NewRateLimiter 初始化限速器
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
return &RateLimiter{
limiter: rate.NewLimiter(rate.Limit(ratePerSecond), burst),
queue: make(chan struct{}, maxWorkers),
maxWorkers: maxWorkers,
}
}
// Allow 检查是否允许请求
func (rl *RateLimiter) Allow(ctx context.Context) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
if err := rl.limiter.Wait(ctx); err != nil {
return false
}
select {
case rl.queue <- struct{}{}:
return true
default:
return false
}
}
// Release 释放并发槽
func (rl *RateLimiter) Release() {
<-rl.queue
}
// Middleware HTTP 中间件
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !rl.Allow(ctx) {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
defer rl.Release()
next.ServeHTTP(w, r)
})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment