Commit c3342c58 authored by Wade

add rate limit

parent b28247b8
......@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect
......
......@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
......
......@@ -4,11 +4,13 @@ import (
"context"
"fmt"
"log"
"net/http"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/firebase/genkit/go/plugins/server"
)
func main() {
......@@ -16,7 +18,7 @@ func main() {
ctx := context.Background()
ds := deepseek.DeepSeek{
APIKey:"sk-9f70df871a7c4b8aa566a3c7a0603706",
APIKey: "sk-9f70df871a7c4b8aa566a3c7a0603706",
}
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds))
......@@ -24,38 +26,48 @@ func main() {
log.Fatal(err)
}
m :=ds.DefineModel(g,
m := ds.DefineModel(g,
deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support
Type: "chat", // Must be chat for tool support
},
nil)
// Define a simple flow that generates jokes about a given topic
//genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil{
fmt.Println(err.Error())
return
}
fmt.Println("resp.Text()",resp.Text())
if err != nil {
fmt.Println(err.Error())
return "", err
}
// if err != nil {
// return "", err
// }
fmt.Println("resp.Text()", resp.Text())
// text := resp.Text()
// return text, nil
// })
if err != nil {
return "", err
}
//<-ctx.Done()
}
text := resp.Text()
return text, nil
})
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl := NewRateLimiter(10, 20, 5)
// 创建 Genkit HTTP 处理器
mux := http.NewServeMux()
for _, a := range genkit.ListFlows(g) {
handler := rl.Middleware(genkit.Handler(a))
mux.Handle("POST /"+a.Name(), handler)
}
// 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
......@@ -16,7 +16,7 @@ import (
const provider = "deepseek"
var (
mediaSupportedModels = []string{deepseek.DeepSeekChat,deepseek.DeepSeekCoder,deepseek.DeepSeekReasoner}
mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
......@@ -34,154 +34,148 @@ var (
}
)
// DeepSeek holds configuration for the plugin.
//
// The struct embeds a sync.Mutex, so it should be shared by pointer
// (*DeepSeek) rather than copied once it is in use.
type DeepSeek struct {
	APIKey string // DeepSeek API key; Init returns an error when it is empty.
	//ServerAddress string

	mu      sync.Mutex // Guards initted.
	initted bool       // Whether the plugin has been initialized by Init.
}
// Name returns the provider name under which this plugin registers itself.
//
// NOTE(review): the value receiver copies the struct (including its mutex) on
// every call; it is kept as-is so existing callers that invoke Name on a
// DeepSeek value keep compiling. Consider a pointer receiver.
func (d DeepSeek) Name() string {
	return provider
}
// ModelDefinition represents a model with its name and type.
type ModelDefinition struct {
	Name string // Model identifier sent to the DeepSeek API (e.g. "deepseek-chat").
	Type string // Model type label supplied by the caller (e.g. "chat").
}
// // DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
d.mu.Lock()
defer d.mu.Unlock()
if !d.initted {
panic("deepseek.Init not called")
}
// Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation.
},
Versions: []string{},
}
if info != nil {
mi = *info
}
meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name,
d.mu.Lock()
defer d.mu.Unlock()
if !d.initted {
panic("deepseek.Init not called")
}
// Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation.
},
Versions: []string{},
}
if info != nil {
mi = *info
}
meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name,
Label: model.Name,
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
}
// Init initializes the DeepSeek plugin.
//
// It returns an error when d is nil or no API key is configured, and panics
// when called a second time on an already initialized plugin. ctx and g are
// not used.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
	// The nil check must come before touching d.mu; locking through a nil
	// receiver would panic instead of reporting the configuration error.
	if d == nil {
		return fmt.Errorf("deepseek: need APIKey")
	}
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.initted {
		panic("deepseek.Init already called")
	}
	if d.APIKey == "" {
		return fmt.Errorf("deepseek: need APIKey")
	}
	d.initted = true
	return nil
}
// generator handles model generation for a single registered model.
type generator struct {
	model  ModelDefinition // Model whose Name is sent with every request.
	apiKey string          // API key used to construct the DeepSeek client.
}
// generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// stream := cb != nil
if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required")
}
if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required")
}
// Set up the Deepseek client
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
}
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
}
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Fatalf("error: %v", err)
}
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Fatalf("error: %v", err)
}
for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
finalResponse.Message.Content = append(finalResponse.Message.Content,&p)
for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
return finalResponse, nil // Return the final merged response
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
}
// concatMessageParts concatenates the text parts of a message into a single
// string, in order. Nil entries and non-text parts are skipped.
func concatMessageParts(parts []*ai.Part) string {
	var sb strings.Builder
	for _, part := range parts {
		// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
		if part != nil && part.IsText() {
			sb.WriteString(part.Text)
		}
	}
	return sb.String()
}
/*
// Choice represents a completion choice generated by the model.
......@@ -205,5 +199,3 @@ type Part struct {
}
*/
......@@ -71,129 +71,129 @@ import (
// Client is a knowledge-base client.
type Client struct {
	BaseURL string // Base URL, e.g. "http://54.92.111.204:5670".
}
// SpaceRequest is the request body for creating a knowledge space.
type SpaceRequest struct {
	ID         int    `json:"id"`
	Name       string `json:"name"`
	VectorType string `json:"vector_type"`
	DomainType string `json:"domain_type"`
	Desc       string `json:"desc"`
	Owner      string `json:"owner"`
	SpaceID    int    `json:"space_id"`
}
// DocumentRequest is the request body for adding a document.
type DocumentRequest struct {
	DocName   string   `json:"doc_name"`
	DocID     int      `json:"doc_id"`
	DocType   string   `json:"doc_type"`
	DocToken  string   `json:"doc_token"`
	Content   string   `json:"content"`
	Source    string   `json:"source"`
	Labels    string   `json:"labels"`
	Questions []string `json:"questions"`
}
// ChunkParameters holds the chunking parameters sent with a sync-batch request.
type ChunkParameters struct {
	ChunkStrategy string `json:"chunk_strategy"`
	TextSplitter  string `json:"text_splitter"`
	SplitterType  string `json:"splitter_type"`
	ChunkSize     int    `json:"chunk_size"`
	ChunkOverlap  int    `json:"chunk_overlap"`
	Separator     string `json:"separator"`
	EnableMerge   bool   `json:"enable_merge"`
}
// SyncBatchRequest is one entry of the request body for synchronous batch
// processing of documents. Note that SpaceID is a string here.
type SyncBatchRequest struct {
	DocID           int             `json:"doc_id"`
	SpaceID         string          `json:"space_id"`
	ModelName       string          `json:"model_name"`
	ChunkParameters ChunkParameters `json:"chunk_parameters"`
}
// NewClient creates a new client instance whose BaseURL is
// "http://<ip>:<port>".
//
// NOTE(review): ip is inserted verbatim, so a bare IPv6 literal would not be
// bracketed — confirm only IPv4 addresses or host names are passed.
func NewClient(ip string, port int) *Client {
	return &Client{
		BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
	}
}
// AddSpace creates a knowledge space by POSTing req as JSON to
// <BaseURL>/knowledge/space/add.
//
// On success the raw HTTP response is returned; the caller must close
// resp.Body and inspect resp.StatusCode itself.
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
	return c.postJSON(fmt.Sprintf("%s/knowledge/space/add", c.BaseURL), req)
}

// AddDocument adds a document to the space identified by spaceID by POSTing
// req as JSON to <BaseURL>/knowledge/<spaceID>/document/add.
//
// On success the raw HTTP response is returned; the caller must close
// resp.Body and inspect resp.StatusCode itself.
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
	return c.postJSON(fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID), req)
}

// SyncBatchDocument triggers synchronous batch processing of documents in the
// space identified by spaceID by POSTing req as JSON to
// <BaseURL>/knowledge/<spaceID>/document/sync_batch.
//
// On success the raw HTTP response is returned; the caller must close
// resp.Body and inspect resp.StatusCode itself.
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
	return c.postJSON(fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID), req)
}

// postJSON marshals payload to JSON and POSTs it to endpoint with JSON
// Accept/Content-Type headers, returning the unread response.
//
// The request goes through http.DefaultClient, which has no timeout.
func (c *Client) postJSON(endpoint string, payload any) (*http.Response, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	return resp, nil
}
......@@ -36,393 +36,391 @@ const provider = "milvus"
// Field names for Milvus schema.
const (
	idField       = "id"       // Primary-key column.
	vectorField   = "vector"   // Embedding vector column.
	textField     = "text"     // Document text column.
	metadataField = "metadata" // Document metadata (JSON) column.
)
// Milvus holds configuration for the plugin.
type Milvus struct {
	// Milvus server address (host:port, e.g., "localhost:19530").
	// Defaults to MILVUS_ADDRESS environment variable.
	Addr string
	// Username for authentication.
	// Defaults to MILVUS_USERNAME.
	Username string
	// Password for authentication.
	// Defaults to MILVUS_PASSWORD.
	Password string
	// Token for authentication (alternative to username/password).
	// Defaults to MILVUS_TOKEN.
	Token string

	client  client.Client // Milvus client, set by Init.
	mu      sync.Mutex    // Mutex to control access.
	initted bool          // Whether the plugin has been initialized.
}
// Name returns the plugin name.
func (m *Milvus) Name() string {
	return provider
}
// Init initializes the Milvus plugin by connecting a Milvus client.
//
// Each connection setting is taken from the corresponding struct field and,
// when that field is empty, from its environment variable (MILVUS_ADDRESS,
// MILVUS_USERNAME, MILVUS_PASSWORD, MILVUS_TOKEN). An address is mandatory.
//
// Every returned error is prefixed with "milvus.Init: ". Init fails when the
// receiver is nil, when the plugin is already initialized, when no address is
// configured, or when the client cannot be created. g is not used.
func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("milvus.Init: %w", err)
		}
	}()

	// A nil receiver cannot retain the client, so initializing it would only
	// open a connection that nobody can use or close.
	if m == nil {
		return errors.New("nil Milvus plugin")
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.initted {
		return errors.New("plugin already initialized")
	}

	// envOr returns value unless it is empty, in which case the named
	// environment variable is consulted.
	envOr := func(value, key string) string {
		if value != "" {
			return value
		}
		return os.Getenv(key)
	}

	// Load configuration.
	addr := envOr(m.Addr, "MILVUS_ADDRESS")
	if addr == "" {
		return errors.New("milvus address required")
	}

	// Initialize Milvus client.
	config := client.Config{
		Address:  addr,
		Username: envOr(m.Username, "MILVUS_USERNAME"),
		Password: envOr(m.Password, "MILVUS_PASSWORD"),
		APIKey:   envOr(m.Token, "MILVUS_TOKEN"),
	}
	cli, err := client.NewClient(ctx, config)
	if err != nil {
		return fmt.Errorf("failed to initialize Milvus client: %w", err)
	}

	m.client = cli
	m.initted = true
	return nil
}
// CollectionConfig holds configuration for an indexer/retriever pair.
type CollectionConfig struct {
	// Milvus collection name. Must not be empty.
	Collection string
	// Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
	// Must be positive.
	Dimension int
	// Embedder for generating vectors. Must not be nil.
	Embedder ai.Embedder
	// Embedder options. When set, must be a map[string]interface{}.
	EmbedderOptions any
}
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
//
// cfg must provide an Embedder, a non-empty Collection name and a positive
// Dimension. The milvus plugin must already be registered with g. Both the
// indexer and the retriever are registered under cfg.Collection.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
	switch {
	case cfg.Embedder == nil:
		return nil, nil, errors.New("milvus: Embedder required")
	case cfg.Collection == "":
		return nil, nil, errors.New("milvus: collection name required")
	case cfg.Dimension <= 0:
		return nil, nil, errors.New("milvus: dimension must be positive")
	}

	m := genkit.LookupPlugin(g, provider)
	if m == nil {
		return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
	}
	// Guard the assertion: something else registered under the same provider
	// name should produce an error, not a panic.
	milvus, ok := m.(*Milvus)
	if !ok {
		return nil, nil, fmt.Errorf("milvus: plugin %q has unexpected type %T", provider, m)
	}

	ds, err := milvus.newDocStore(ctx, &cfg)
	if err != nil {
		return nil, nil, err
	}

	indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
	retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
	return indexer, retriever, nil
}
// docStore defines an Indexer and a Retriever backed by one Milvus collection.
type docStore struct {
	client          client.Client          // Milvus client shared with the plugin.
	collection      string                 // Collection name.
	dimension       int                    // Embedding vector dimension.
	embedder        ai.Embedder            // Embedder used for documents and queries.
	embedderOptions map[string]interface{} // Options passed with every embed request.
}
// newDocStore creates a docStore for cfg.Collection.
//
// If the collection does not exist yet it is created with an auto-ID int64
// primary key, a float vector of cfg.Dimension, a VarChar text field and a
// JSON metadata field, and an HNSW index is built on the vector field. The
// collection is then loaded.
//
// It returns an error when Init has not been called, when any Milvus call
// fails, or when cfg.EmbedderOptions is set but is not a
// map[string]interface{}.
//
// NOTE(review): the index is always created with the L2 metric — confirm that
// retrieval with a different metric type is intended to work.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
	if m.client == nil {
		return nil, errors.New("milvus.Init not called")
	}

	// Check/create collection.
	exists, err := m.client.HasCollection(ctx, cfg.Collection)
	if err != nil {
		return nil, fmt.Errorf("failed to check collection %q: %w", cfg.Collection, err)
	}
	if !exists {
		// Define schema.
		schema := &entity.Schema{
			CollectionName: cfg.Collection,
			Fields: []*entity.Field{
				{
					Name:       idField,
					DataType:   entity.FieldTypeInt64,
					PrimaryKey: true,
					AutoID:     true,
				},
				{
					Name:     vectorField,
					DataType: entity.FieldTypeFloatVector,
					TypeParams: map[string]string{
						"dim": fmt.Sprintf("%d", cfg.Dimension),
					},
				},
				{
					Name:     textField,
					DataType: entity.FieldTypeVarChar,
					TypeParams: map[string]string{
						"max_length": "65535",
					},
				},
				{
					Name:     metadataField,
					DataType: entity.FieldTypeJSON,
				},
			},
		}
		if err := m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil {
			return nil, fmt.Errorf("failed to create collection %q: %w", cfg.Collection, err)
		}

		// Create HNSW index.
		index, err := entity.NewIndexHNSW(
			entity.L2,
			8,  // M
			96, // efConstruction
		)
		if err != nil {
			return nil, fmt.Errorf("entity.NewIndexHNSW: %w", err)
		}
		if err := m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false); err != nil {
			return nil, fmt.Errorf("failed to create index: %w", err)
		}
	}

	// Load collection.
	if err := m.client.LoadCollection(ctx, cfg.Collection, false); err != nil {
		return nil, fmt.Errorf("failed to load collection %q: %w", cfg.Collection, err)
	}

	// Convert EmbedderOptions to map[string]interface{}.
	var embedderOptions map[string]interface{}
	switch opts := cfg.EmbedderOptions.(type) {
	case nil:
		embedderOptions = make(map[string]interface{})
	case map[string]interface{}:
		embedderOptions = opts
	default:
		return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
	}

	return &docStore{
		client:          m.client,
		collection:      cfg.Collection,
		dimension:       cfg.Dimension,
		embedder:        cfg.Embedder,
		embedderOptions: embedderOptions,
	}, nil
}
// Indexer returns the indexer registered for a collection under the milvus
// provider, as looked up by genkit.LookupIndexer.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
	return genkit.LookupIndexer(g, provider, collection)
}
// Retriever returns the retriever registered for a collection under the
// milvus provider, as looked up by genkit.LookupRetriever.
func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
	return genkit.LookupRetriever(g, provider, collection)
}
// Index implements the Indexer.Index method.
//
// All documents in req are embedded in one call, then inserted into the
// collection as rows holding the embedding vector, the concatenated text
// parts of the document, and its metadata (an empty map when the document has
// none). A request without documents is a no-op.
//
// It returns an error when embedding fails, when the embedder returns a
// different number of embeddings than documents, or when the insert fails.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
	if len(req.Documents) == 0 {
		return nil
	}

	// Embed documents.
	ereq := &ai.EmbedRequest{
		Input:   req.Documents,
		Options: ds.embedderOptions,
	}
	eres, err := ds.embedder.Embed(ctx, ereq)
	if err != nil {
		return fmt.Errorf("milvus index embedding failed: %w", err)
	}

	// Validate embedding count matches document count.
	if len(eres.Embeddings) != len(req.Documents) {
		return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
	}

	// Prepare row-based data, one row per document.
	rows := make([]interface{}, 0, len(eres.Embeddings))
	for i, emb := range eres.Embeddings {
		doc := req.Documents[i]

		var sb strings.Builder
		for _, p := range doc.Content {
			if p.IsText() {
				sb.WriteString(p.Text)
			}
		}

		metadata := doc.Metadata
		if metadata == nil {
			metadata = make(map[string]interface{})
		}

		// Keys are the schema field names, so rows stay in sync with the
		// collection definition.
		rows = append(rows, map[string]interface{}{
			vectorField:   emb.Embedding, // []float32
			textField:     sb.String(),
			metadataField: metadata, // JSON-compatible map
		})
	}

	// Insert rows into Milvus.
	if _, err := ds.client.InsertRows(ctx, ds.collection, "", rows); err != nil {
		return fmt.Errorf("milvus insert rows failed: %w", err)
	}
	return nil
}
// RetrieverOptions for Milvus retrieval.
type RetrieverOptions struct {
	Count      int    `json:"count,omitempty"`       // Max documents to retrieve.
	MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// Retrieve implements the Retriever.Retrieve method.
//
// It embeds req.Query with the store's embedder, searches the collection's
// vector field for the nearest neighbours, and returns the text of each hit
// as a document. req.Options, when set, must be a *RetrieverOptions; a
// positive Count overrides the default of 3 results and a non-empty
// MetricType ("L2" or "IP") overrides the default "L2" metric.
//
// An error is returned for an options value of the wrong type, an unsupported
// metric type, an embedding failure or empty embedding result, or a failed
// search. Underlying errors are wrapped so callers can inspect them with
// errors.Is / errors.As.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
	count := 3 // Default.
	metricTypeStr := "L2"
	if req.Options != nil {
		ropt, ok := req.Options.(*RetrieverOptions)
		if !ok {
			return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
		}
		if ropt.Count > 0 {
			count = ropt.Count
		}
		if ropt.MetricType != "" {
			metricTypeStr = ropt.MetricType
		}
	}

	// Map string metric type to entity.MetricType.
	var metricType entity.MetricType
	switch metricTypeStr {
	case "L2":
		metricType = entity.L2
	case "IP":
		metricType = entity.IP
	default:
		return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
	}

	// Embed query.
	ereq := &ai.EmbedRequest{
		Input:   []*ai.Document{req.Query},
		Options: ds.embedderOptions,
	}
	eres, err := ds.embedder.Embed(ctx, ereq)
	if err != nil {
		return nil, fmt.Errorf("milvus retrieve embedding failed: %w", err)
	}
	if len(eres.Embeddings) == 0 {
		return nil, errors.New("no embeddings generated for query")
	}
	queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)

	// Create search parameters.
	searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
	if err != nil {
		return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %w", err)
	}

	// Perform search.
	results, err := ds.client.Search(
		ctx,
		ds.collection,
		[]string{},                         // partitions
		"",                                 // expr
		[]string{textField, metadataField}, // output fields
		[]entity.Vector{queryVector},
		vectorField,
		metricType,
		count,
		searchParams,
	)
	if err != nil {
		return nil, fmt.Errorf("milvus search failed: %w", err)
	}

	// Process results.
	var docs []*ai.Document
	for _, result := range results {
		// The column is the same for every row of a result, so look it up once.
		textCol := result.Fields.GetColumn(textField)
		if textCol == nil {
			// No text column in this result set; there is nothing to build
			// documents from, and calling a method on it would panic.
			continue
		}
		for i := 0; i < result.ResultCount; i++ {
			text, err := textCol.GetAsString(i)
			if err != nil {
				// Best effort: rows whose text cannot be read are skipped.
				continue
			}
			// NOTE(review): metadataField is requested as an output field but
			// its column is never decoded here, so documents are returned
			// with nil metadata.
			var metadata map[string]interface{}
			docs = append(docs, ai.DocumentFromText(text, metadata))
		}
	}

	return &ai.RetrieverResponse{
		Documents: docs,
	}, nil
}
......@@ -33,156 +33,155 @@ import (
// MockEmbedder is a stateless embedder stand-in used by the tests in this
// file so that no real embedding model is needed.
type MockEmbedder struct{}

// Name returns the fixed identifier of the mock embedder.
func (m *MockEmbedder) Name() string {
	return "mock-embedder"
}
// Embed returns one embedding per entry in req.Input. Every embedding is a
// freshly allocated 768-element vector filled with 1.0; the content of the
// input documents is ignored. It never returns an error.
func (m *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
	resp := &ai.EmbedResponse{}
	for range req.Input {
		// Generate a simple embedding (768-dimensional vector of ones).
		embedding := make([]float32, 768)
		for i := range embedding {
			embedding[i] = 1.0
		}
		resp.Embeddings = append(resp.Embeddings, &ai.Embedding{Embedding: embedding})
	}
	return resp, nil
}
// dropCollection cleans up a test collection.
//
// It asks the server whether collectionName exists and drops it only if it
// does, so calling it for a missing collection is not an error. Failures of
// either call are returned wrapped with the step that failed.
func dropCollection(ctx context.Context, client client.Client, collectionName string) error {
	exists, err := client.HasCollection(ctx, collectionName)
	if err != nil {
		return fmt.Errorf("check collection: %w", err)
	}
	if exists {
		err = client.DropCollection(ctx, collectionName)
		if err != nil {
			return fmt.Errorf("drop collection: %w", err)
		}
	}
	return nil
}
// TestMilvusIntegration exercises the Milvus plugin end to end against the
// Milvus server at the address configured below: it defines an indexer and
// retriever on a uniquely named collection, then checks indexing, retrieval,
// and the error paths for invalid retriever and embedder options.
func TestMilvusIntegration(t *testing.T) {
	ctx := context.Background()

	// Initialize Milvus plugin
	ms := Milvus{
		Addr: "54.92.111.204:19530", // Milvus gRPC endpoint
	}

	// Initialize Genkit with Milvus plugin
	g, err := genkit.Init(ctx, genkit.WithPlugins(&ms))
	if err != nil {
		t.Fatalf("genkit.Init failed: %v", err)
	}

	// Get the Milvus client for cleanup
	m, ok := genkit.LookupPlugin(g, provider).(*Milvus)
	if !ok {
		t.Fatalf("Failed to lookup Milvus plugin")
	}
	defer m.client.Close()

	// Generate unique collection name
	collectionName := fmt.Sprintf("test_collection_%d", time.Now().UnixNano())

	// Configure collection
	cfg := CollectionConfig{
		Collection:      collectionName,
		Dimension:       768, // Match mock embedder dimension
		Embedder:        &MockEmbedder{},
		EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
	}

	// Define indexer and retriever
	indexer, retriever, err := DefineIndexerAndRetriever(ctx, g, cfg)
	if err != nil {
		t.Fatalf("DefineIndexerAndRetriever failed: %v", err)
	}

	// Clean up collection after test
	// defer func() {
	//     if err := dropCollection(ctx, m.client, collectionName); err != nil {
	//         t.Errorf("Cleanup failed: %v", err)
	//     }
	// }()

	t.Run("Index and Retrieve", func(t *testing.T) {
		// Index documents
		documents := []*ai.Document{
			{
				Content:  []*ai.Part{ai.NewTextPart("Hello world")},
				Metadata: map[string]interface{}{"id": int64(1), "category": "greeting"},
			},
			{
				Content:  []*ai.Part{ai.NewTextPart("AI is amazing")},
				Metadata: map[string]interface{}{"id": int64(2), "category": "tech"},
			},
		}
		req := &ai.IndexerRequest{Documents: documents}
		err := indexer.Index(ctx, req)
		if err != nil {
			t.Fatalf("Index failed: %v", err)
		}

		// Wait briefly to ensure Milvus processes the index
		time.Sleep(1 * time.Second)

		// Retrieve documents
		queryReq := &ai.RetrieverRequest{
			Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
			Options: &RetrieverOptions{
				Count:      2,
				MetricType: "L2",
			},
		}
		resp, err := retriever.Retrieve(ctx, queryReq)
		if err != nil {
			t.Fatalf("Retrieve failed: %v", err)
		}

		// Verify results
		assert.NotNil(t, resp, "Response should not be nil")
		assert.NotEmpty(t, resp.Documents, "Should return at least one document")
		for _, doc := range resp.Documents {
			assert.NotEmpty(t, doc.Content[0].Text, "Document text should not be empty")
			// Note: Mock embedder returns identical vectors, so results may not be exact
			if strings.Contains(doc.Content[0].Text, "Hello world") || strings.Contains(doc.Content[0].Text, "AI is amazing") {
				continue
			}
			t.Errorf("Unexpected document text: %s", doc.Content[0].Text)
		}
	})

	t.Run("Empty Index", func(t *testing.T) {
		req := &ai.IndexerRequest{Documents: []*ai.Document{}}
		err := indexer.Index(ctx, req)
		assert.NoError(t, err, "Indexing empty documents should succeed")
	})

	t.Run("Invalid Retrieve Options", func(t *testing.T) {
		queryReq := &ai.RetrieverRequest{
			Query:   &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
			Options: &RetrieverOptions{MetricType: "INVALID"},
		}
		_, err := retriever.Retrieve(ctx, queryReq)
		// Only inspect the message when an error was actually returned;
		// calling err.Error() on a nil error would panic instead of failing.
		if assert.Error(t, err, "Should fail with invalid metric type") {
			assert.Contains(t, err.Error(), "unsupported metric type")
		}
	})

	t.Run("Invalid Embedder Options", func(t *testing.T) {
		// Test with invalid EmbedderOptions type
		invalidCfg := CollectionConfig{
			Collection:      collectionName + "_invalid",
			Dimension:       768,
			Embedder:        &MockEmbedder{},
			EmbedderOptions: "not-a-map", // Invalid type
		}
		_, _, err := DefineIndexerAndRetriever(ctx, g, invalidCfg)
		// Same guard as above: avoid a nil dereference when no error comes back.
		if assert.Error(t, err, "Should fail with invalid EmbedderOptions type") {
			assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}")
		}
	})
}
......@@ -12,188 +12,188 @@ import (
)
var (
	// connString holds the value of the -dbconn command-line flag; it
	// defaults to the empty string.
	connString = flag.String("dbconn", "", "database connection string")
)
// QA represents one record of the qa table. Pointer fields are nullable
// columns: a nil pointer stands for SQL NULL.
type QA struct {
	ID        int64     // Primary key.
	CreatedAt time.Time // Creation time.
	UserID    *int64    // Nullable user ID.
	Username  *string   // Nullable user name.
	Question  *string   // Nullable question.
	Answer    *string   // Nullable answer.
}
// QAStore defines the data-access interface for QA records.
type QAStore interface {
	// GetLatestQA reads the latest records for the given user_id from the
	// latest_qa view.
	GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
	// WriteQA inserts or updates a record in the qa table.
	WriteQA(ctx context.Context, qa QA) (int64, error)
}
// qaStore is the implementation of the QAStore interface; it runs its
// queries through the wrapped *sql.DB.
type qaStore struct {
	db *sql.DB
}
// NewQAStore creates a new QAStore that runs its queries against db.
func NewQAStore(db *sql.DB) QAStore {
	return &qaStore{db: db}
}
// GetLatestQA 从 latest_qa 视图读取数据
func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) {
query := `
query := `
SELECT id, created_at, user_id, username, question, answer
FROM latest_qa
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
args := []interface{}{userID}
if userID == nil {
args = []interface{}{nil}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err)
}
defer rows.Close()
var results []QA
for rows.Next() {
var qa QA
var userIDVal sql.NullInt64
var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
if userIDVal.Valid {
qa.UserID = &userIDVal.Int64
}
if username.Valid {
qa.Username = &username.String
}
if question.Valid {
qa.Question = &question.String
}
if answer.Valid {
qa.Answer = &answer.String
}
results = append(results, qa)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err)
}
return results, nil
args := []interface{}{userID}
if userID == nil {
args = []interface{}{nil}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err)
}
defer rows.Close()
var results []QA
for rows.Next() {
var qa QA
var userIDVal sql.NullInt64
var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
if userIDVal.Valid {
qa.UserID = &userIDVal.Int64
}
if username.Valid {
qa.Username = &username.String
}
if question.Valid {
qa.Question = &question.String
}
if answer.Valid {
qa.Answer = &answer.String
}
results = append(results, qa)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err)
}
return results, nil
}
// WriteQA 插入或更新 qa 表记录
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 {
// 更新记录
query := `
if qa.ID != 0 {
// 更新记录
query := `
UPDATE qa
SET user_id = $1, username = $2, question = $3, answer = $4
WHERE id = $5
RETURNING id`
var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID)
}
if err != nil {
return 0, fmt.Errorf("update qa: %w", err)
}
return updatedID, nil
}
// 插入新记录
query := `
var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID)
}
if err != nil {
return 0, fmt.Errorf("update qa: %w", err)
}
return updatedID, nil
}
// 插入新记录
query := `
INSERT INTO qa (user_id, username, question, answer)
VALUES ($1, $2, $3, $4)
RETURNING id`
var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil {
return 0, fmt.Errorf("insert qa: %w", err)
}
return newID, nil
var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil {
return 0, fmt.Errorf("insert qa: %w", err)
}
return newID, nil
}
// mainQA is a small demonstration of the QA store: it opens the database
// named by the -dbconn flag, prints the latest QA rows for user 101, inserts
// a new row, and then updates that row. Any failure terminates the process
// via log.Fatal.
func mainQA() {
	flag.Parse()
	ctx := context.Background()

	if *connString == "" {
		log.Fatal("need -dbconn")
	}

	db, err := sql.Open("postgres", *connString)
	if err != nil {
		log.Fatalf("open database: %v", err)
	}
	defer db.Close()

	store := NewQAStore(db)

	// Example: read the latest QA for user_id=101.
	results, err := store.GetLatestQA(ctx, int64Ptr(101))
	if err != nil {
		log.Fatalf("get latest QA: %v", err)
	}
	for _, qa := range results {
		fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
			qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
	}

	// Example: insert a new QA.
	newQA := QA{
		UserID:   int64Ptr(101),
		Username: stringPtr("alice"),
		Question: stringPtr("What is AI?"),
		Answer:   stringPtr("AI is..."),
	}
	newID, err := store.WriteQA(ctx, newQA)
	if err != nil {
		log.Fatalf("write QA: %v", err)
	}
	fmt.Printf("Inserted QA with ID: %d\n", newID)

	// Example: update the QA that was just inserted.
	updateQA := QA{
		ID:       newID,
		UserID:   int64Ptr(101),
		Username: stringPtr("alice_updated"),
		Question: stringPtr("What is NLP?"),
		Answer:   stringPtr("NLP is..."),
	}
	updatedID, err := store.WriteQA(ctx, updateQA)
	if err != nil {
		log.Fatalf("update QA: %v", err)
	}
	fmt.Printf("Updated QA with ID: %d\n", updatedID)
}
// Helper functions for working with nullable (pointer-typed) values.

// int64Ptr returns a pointer to a fresh copy of i.
func int64Ptr(i int64) *int64 {
	return &i
}
// stringPtr returns a pointer to a fresh copy of s.
func stringPtr(s string) *string {
	return &s
}
// derefInt64 returns the value p points to, or an untyped nil when p is nil,
// so that a NULL column prints as <nil> rather than as a pointer address.
func derefInt64(p *int64) interface{} {
	if p == nil {
		return nil
	}
	return *p
}
// derefString returns the string p points to, or an untyped nil when p is
// nil, so that a NULL column prints as <nil> rather than as a pointer address.
func derefString(p *string) interface{} {
	if p == nil {
		return nil
	}
	return *p
}
package main
import (
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter combines a token-bucket rate limiter with a bounded pool of
// concurrency slots.
type RateLimiter struct {
limiter *rate.Limiter // token bucket that Allow waits on before admitting a request
queue chan struct{} // buffered channel used as a slot pool: one element per admitted request
maxWorkers int // capacity the queue channel was created with
mu sync.Mutex // lock taken by Allow
}
// NewRateLimiter builds a RateLimiter whose token bucket refills at
// ratePerSecond with the given burst size and whose slot pool holds at most
// maxWorkers concurrent requests.
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiter = rate.NewLimiter(rate.Limit(ratePerSecond), burst)
	rl.queue = make(chan struct{}, maxWorkers)
	rl.maxWorkers = maxWorkers
	return rl
}
// Allow reports whether the request may proceed.
//
// It first waits on the rate limiter until a token is available or ctx is
// done, then tries to take a concurrency slot without blocking. It returns
// false when the wait fails (for example because ctx was cancelled) or when
// all slots are in use. Every call that returns true must be paired with a
// call to Release.
func (rl *RateLimiter) Allow(ctx context.Context) bool {
	// No lock is taken here. rate.Limiter and channel operations are safe for
	// concurrent use on their own, and holding rl.mu across the blocking Wait
	// would serialize every caller and make them queue on a mutex that does
	// not observe ctx cancellation.
	if err := rl.limiter.Wait(ctx); err != nil {
		return false
	}
	select {
	case rl.queue <- struct{}{}:
		return true
	default:
		// All concurrency slots are taken; reject instead of queueing.
		return false
	}
}
// Release frees one concurrency slot previously taken by a successful Allow.
//
// The receive is non-blocking: a Release that has no matching Allow is a
// no-op instead of blocking the calling goroutine until some other request
// happens to acquire a slot.
func (rl *RateLimiter) Release() {
	select {
	case <-rl.queue:
	default:
		// No slot is held; nothing to release.
	}
}
// Middleware wraps next in an HTTP handler that admits a request only when
// Allow succeeds for the request's context. Rejected requests receive
// 429 Too Many Requests; admitted requests hold their slot until next
// returns.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		if admitted := rl.Allow(r.Context()); !admitted {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}
		// Give the slot back once the wrapped handler has finished.
		defer rl.Release()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(guarded)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment