Commit c3342c58 authored by Wade's avatar Wade

add rate limit

parent b28247b8
......@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect
......
......@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
......
......@@ -4,11 +4,13 @@ import (
"context"
"fmt"
"log"
"net/http"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/firebase/genkit/go/plugins/server"
)
func main() {
......@@ -16,7 +18,7 @@ func main() {
ctx := context.Background()
ds := deepseek.DeepSeek{
APIKey:"sk-9f70df871a7c4b8aa566a3c7a0603706",
APIKey: "sk-9f70df871a7c4b8aa566a3c7a0603706",
}
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds))
......@@ -24,38 +26,48 @@ func main() {
log.Fatal(err)
}
m :=ds.DefineModel(g,
m := ds.DefineModel(g,
deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support
Type: "chat", // Must be chat for tool support
},
nil)
// Define a simple flow that generates jokes about a given topic
//genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil{
fmt.Println(err.Error())
return
}
fmt.Println("resp.Text()",resp.Text())
if err != nil {
fmt.Println(err.Error())
return "", err
}
// if err != nil {
// return "", err
// }
fmt.Println("resp.Text()", resp.Text())
// text := resp.Text()
// return text, nil
// })
if err != nil {
return "", err
}
//<-ctx.Done()
}
text := resp.Text()
return text, nil
})
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl := NewRateLimiter(10, 20, 5)
// 创建 Genkit HTTP 处理器
mux := http.NewServeMux()
for _, a := range genkit.ListFlows(g) {
handler := rl.Middleware(genkit.Handler(a))
mux.Handle("POST /"+a.Name(), handler)
}
// 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
......@@ -16,7 +16,7 @@ import (
const provider = "deepseek"
var (
mediaSupportedModels = []string{deepseek.DeepSeekChat,deepseek.DeepSeekCoder,deepseek.DeepSeekReasoner}
mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
......@@ -34,154 +34,148 @@ var (
}
)
// DeepSeek holds configuration for the plugin.
type DeepSeek struct {
APIKey string // DeepSeek API key
//ServerAddress string
APIKey string // DeepSeek API key
//ServerAddress string
mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized.
mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized.
}
// Name returns the provider name.
func (d DeepSeek) Name() string {
return provider
return provider
}
// ModelDefinition represents a model with its name and type.
type ModelDefinition struct {
Name string
Type string
Name string
Type string
}
// // DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
d.mu.Lock()
defer d.mu.Unlock()
if !d.initted {
panic("deepseek.Init not called")
}
// Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation.
},
Versions: []string{},
}
if info != nil {
mi = *info
}
meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name,
d.mu.Lock()
defer d.mu.Unlock()
if !d.initted {
panic("deepseek.Init not called")
}
// Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation.
},
Versions: []string{},
}
if info != nil {
mi = *info
}
meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name,
Label: model.Name,
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
}
// Init initializes the DeepSeek plugin.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
d.mu.Lock()
defer d.mu.Unlock()
if d.initted {
panic("deepseek.Init already called")
}
if d == nil || d.APIKey == "" {
return fmt.Errorf("deepseek: need APIKey")
}
d.initted = true
return nil
d.mu.Lock()
defer d.mu.Unlock()
if d.initted {
panic("deepseek.Init already called")
}
if d == nil || d.APIKey == "" {
return fmt.Errorf("deepseek: need APIKey")
}
d.initted = true
return nil
}
// generator handles model generation.
type generator struct {
model ModelDefinition
apiKey string
model ModelDefinition
apiKey string
}
// generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// stream := cb != nil
if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required")
}
if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required")
}
// Set up the Deepseek client
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
}
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
}
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Fatalf("error: %v", err)
}
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Fatalf("error: %v", err)
}
for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
finalResponse.Message.Content = append(finalResponse.Message.Content,&p)
for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
return finalResponse, nil // Return the final merged response
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
}
// concatMessageParts concatenates message parts into a single string.
func concatMessageParts(parts []*ai.Part) string {
var sb strings.Builder
for _, part := range parts {
if part.IsText() {
sb.WriteString(part.Text)
}
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
}
return sb.String()
var sb strings.Builder
for _, part := range parts {
if part.IsText() {
sb.WriteString(part.Text)
}
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
}
return sb.String()
}
/*
// Choice represents a completion choice generated by the model.
......@@ -205,5 +199,3 @@ type Part struct {
}
*/
......@@ -71,129 +71,129 @@ import (
// Client 知识库客户端
type Client struct {
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
}
// SpaceRequest 创建空间的请求结构体
type SpaceRequest struct {
ID int `json:"id"`
Name string `json:"name"`
VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"`
Desc string `json:"desc"`
Owner string `json:"owner"`
SpaceID int `json:"space_id"`
ID int `json:"id"`
Name string `json:"name"`
VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"`
Desc string `json:"desc"`
Owner string `json:"owner"`
SpaceID int `json:"space_id"`
}
// DocumentRequest 添加文档的请求结构体
type DocumentRequest struct {
DocName string `json:"doc_name"`
DocID int `json:"doc_id"`
DocType string `json:"doc_type"`
DocToken string `json:"doc_token"`
Content string `json:"content"`
Source string `json:"source"`
Labels string `json:"labels"`
Questions []string `json:"questions"`
DocName string `json:"doc_name"`
DocID int `json:"doc_id"`
DocType string `json:"doc_type"`
DocToken string `json:"doc_token"`
Content string `json:"content"`
Source string `json:"source"`
Labels string `json:"labels"`
Questions []string `json:"questions"`
}
// ChunkParameters 分片参数
type ChunkParameters struct {
ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"`
ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"`
}
// SyncBatchRequest 同步批处理的请求结构体
type SyncBatchRequest struct {
DocID int `json:"doc_id"`
SpaceID string `json:"space_id"`
ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"`
DocID int `json:"doc_id"`
SpaceID string `json:"space_id"`
ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"`
}
// NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client {
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
}
// AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// SyncBatchDocument 同步批处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
This diff is collapsed.
This diff is collapsed.
......@@ -12,188 +12,188 @@ import (
)
var (
connString = flag.String("dbconn", "", "database connection string")
connString = flag.String("dbconn", "", "database connection string")
)
// QA 结构体表示 qa 表的记录
type QA struct {
ID int64 // 主键
CreatedAt time.Time // 创建时间
UserID *int64 // 可空的用户 ID
Username *string // 可空的用户名
Question *string // 可空的问题
Answer *string // 可空的答案
ID int64 // 主键
CreatedAt time.Time // 创建时间
UserID *int64 // 可空的用户 ID
Username *string // 可空的用户名
Question *string // 可空的问题
Answer *string // 可空的答案
}
// QAStore 定义 DAO 接口
type QAStore interface {
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
// WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error)
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
// WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error)
}
// qaStore 是 QAStore 接口的实现
type qaStore struct {
db *sql.DB
db *sql.DB
}
// NewQAStore 创建新的 QAStore 实例
func NewQAStore(db *sql.DB) QAStore {
return &qaStore{db: db}
return &qaStore{db: db}
}
// GetLatestQA 从 latest_qa 视图读取数据
func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) {
query := `
query := `
SELECT id, created_at, user_id, username, question, answer
FROM latest_qa
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
args := []interface{}{userID}
if userID == nil {
args = []interface{}{nil}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err)
}
defer rows.Close()
var results []QA
for rows.Next() {
var qa QA
var userIDVal sql.NullInt64
var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
if userIDVal.Valid {
qa.UserID = &userIDVal.Int64
}
if username.Valid {
qa.Username = &username.String
}
if question.Valid {
qa.Question = &question.String
}
if answer.Valid {
qa.Answer = &answer.String
}
results = append(results, qa)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err)
}
return results, nil
args := []interface{}{userID}
if userID == nil {
args = []interface{}{nil}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err)
}
defer rows.Close()
var results []QA
for rows.Next() {
var qa QA
var userIDVal sql.NullInt64
var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
if userIDVal.Valid {
qa.UserID = &userIDVal.Int64
}
if username.Valid {
qa.Username = &username.String
}
if question.Valid {
qa.Question = &question.String
}
if answer.Valid {
qa.Answer = &answer.String
}
results = append(results, qa)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err)
}
return results, nil
}
// WriteQA 插入或更新 qa 表记录
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 {
// 更新记录
query := `
if qa.ID != 0 {
// 更新记录
query := `
UPDATE qa
SET user_id = $1, username = $2, question = $3, answer = $4
WHERE id = $5
RETURNING id`
var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID)
}
if err != nil {
return 0, fmt.Errorf("update qa: %w", err)
}
return updatedID, nil
}
// 插入新记录
query := `
var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID)
}
if err != nil {
return 0, fmt.Errorf("update qa: %w", err)
}
return updatedID, nil
}
// 插入新记录
query := `
INSERT INTO qa (user_id, username, question, answer)
VALUES ($1, $2, $3, $4)
RETURNING id`
var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil {
return 0, fmt.Errorf("insert qa: %w", err)
}
return newID, nil
var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil {
return 0, fmt.Errorf("insert qa: %w", err)
}
return newID, nil
}
func mainQA() {
flag.Parse()
ctx := context.Background()
if *connString == "" {
log.Fatal("need -dbconn")
}
db, err := sql.Open("postgres", *connString)
if err != nil {
log.Fatalf("open database: %v", err)
}
defer db.Close()
store := NewQAStore(db)
// 示例:读取 user_id=101 的最新 QA
results, err := store.GetLatestQA(ctx, int64Ptr(101))
if err != nil {
log.Fatalf("get latest QA: %v", err)
}
for _, qa := range results {
fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
}
// 示例:插入新 QA
newQA := QA{
UserID: int64Ptr(101),
Username: stringPtr("alice"),
Question: stringPtr("What is AI?"),
Answer: stringPtr("AI is..."),
}
newID, err := store.WriteQA(ctx, newQA)
if err != nil {
log.Fatalf("write QA: %v", err)
}
fmt.Printf("Inserted QA with ID: %d\n", newID)
// 示例:更新 QA
updateQA := QA{
ID: newID,
UserID: int64Ptr(101),
Username: stringPtr("alice_updated"),
Question: stringPtr("What is NLP?"),
Answer: stringPtr("NLP is..."),
}
updatedID, err := store.WriteQA(ctx, updateQA)
if err != nil {
log.Fatalf("update QA: %v", err)
}
fmt.Printf("Updated QA with ID: %d\n", updatedID)
flag.Parse()
ctx := context.Background()
if *connString == "" {
log.Fatal("need -dbconn")
}
db, err := sql.Open("postgres", *connString)
if err != nil {
log.Fatalf("open database: %v", err)
}
defer db.Close()
store := NewQAStore(db)
// 示例:读取 user_id=101 的最新 QA
results, err := store.GetLatestQA(ctx, int64Ptr(101))
if err != nil {
log.Fatalf("get latest QA: %v", err)
}
for _, qa := range results {
fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
}
// 示例:插入新 QA
newQA := QA{
UserID: int64Ptr(101),
Username: stringPtr("alice"),
Question: stringPtr("What is AI?"),
Answer: stringPtr("AI is..."),
}
newID, err := store.WriteQA(ctx, newQA)
if err != nil {
log.Fatalf("write QA: %v", err)
}
fmt.Printf("Inserted QA with ID: %d\n", newID)
// 示例:更新 QA
updateQA := QA{
ID: newID,
UserID: int64Ptr(101),
Username: stringPtr("alice_updated"),
Question: stringPtr("What is NLP?"),
Answer: stringPtr("NLP is..."),
}
updatedID, err := store.WriteQA(ctx, updateQA)
if err != nil {
log.Fatalf("update QA: %v", err)
}
fmt.Printf("Updated QA with ID: %d\n", updatedID)
}
// 辅助函数:处理指针类型的空值
func int64Ptr(i int64) *int64 {
return &i
return &i
}
func stringPtr(s string) *string {
return &s
return &s
}
func derefInt64(p *int64) interface{} {
if p == nil {
return nil
}
return *p
if p == nil {
return nil
}
return *p
}
func derefString(p *string) interface{} {
if p == nil {
return nil
}
return *p
}
\ No newline at end of file
if p == nil {
return nil
}
return *p
}
package main
import (
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter 定义限速器和并发队列
type RateLimiter struct {
limiter *rate.Limiter
queue chan struct{}
maxWorkers int
mu sync.Mutex
}
// NewRateLimiter 初始化限速器
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
return &RateLimiter{
limiter: rate.NewLimiter(rate.Limit(ratePerSecond), burst),
queue: make(chan struct{}, maxWorkers),
maxWorkers: maxWorkers,
}
}
// Allow 检查是否允许请求
func (rl *RateLimiter) Allow(ctx context.Context) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
if err := rl.limiter.Wait(ctx); err != nil {
return false
}
select {
case rl.queue <- struct{}{}:
return true
default:
return false
}
}
// Release 释放并发槽
func (rl *RateLimiter) Release() {
<-rl.queue
}
// Middleware HTTP 中间件
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !rl.Allow(ctx) {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
defer rl.Release()
next.ServeHTTP(w, r)
})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment