Commit c3342c58 authored by Wade

add rate limit

parent b28247b8
...@@ -80,6 +80,7 @@ require ( ...@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect google.golang.org/grpc v1.72.0 // indirect
......
...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= ...@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
......
...@@ -4,11 +4,13 @@ import ( ...@@ -4,11 +4,13 @@ import (
	"context"
	"fmt"
	"log"
	"net/http"
	"os"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/firebase/genkit/go/plugins/server"
	"github.com/wade-liwei/agentchat/plugins/deepseek"
) )
func main() { func main() {
...@@ -16,7 +18,7 @@ func main() { ...@@ -16,7 +18,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ds := deepseek.DeepSeek{ ds := deepseek.DeepSeek{
APIKey:"sk-9f70df871a7c4b8aa566a3c7a0603706", APIKey: "sk-9f70df871a7c4b8aa566a3c7a0603706",
} }
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds)) g, err := genkit.Init(ctx, genkit.WithPlugins(&ds))
...@@ -24,38 +26,48 @@ func main() { ...@@ -24,38 +26,48 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
m :=ds.DefineModel(g, m := ds.DefineModel(g,
deepseek.ModelDefinition{ deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support Type: "chat", // Must be chat for tool support
}, },
nil) nil)
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
//genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`)) ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil{ if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
return return "", err
} }
fmt.Println("resp.Text()",resp.Text())
// if err != nil { fmt.Println("resp.Text()", resp.Text())
// return "", err
// }
// text := resp.Text() if err != nil {
// return text, nil return "", err
// }) }
//<-ctx.Done() text := resp.Text()
} return text, nil
})
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl := NewRateLimiter(10, 20, 5)
// 创建 Genkit HTTP 处理器
mux := http.NewServeMux()
for _, a := range genkit.ListFlows(g) {
handler := rl.Middleware(genkit.Handler(a))
mux.Handle("POST /"+a.Name(), handler)
}
// 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
const provider = "deepseek" const provider = "deepseek"
var ( var (
mediaSupportedModels = []string{deepseek.DeepSeekChat,deepseek.DeepSeekCoder,deepseek.DeepSeekReasoner} mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{ // toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral", // "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2", // "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
...@@ -34,154 +34,148 @@ var ( ...@@ -34,154 +34,148 @@ var (
} }
) )
// DeepSeek holds configuration for the plugin. // DeepSeek holds configuration for the plugin.
type DeepSeek struct { type DeepSeek struct {
APIKey string // DeepSeek API key APIKey string // DeepSeek API key
//ServerAddress string //ServerAddress string
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the provider name. // Name returns the provider name.
func (d DeepSeek) Name() string { func (d DeepSeek) Name() string {
return provider return provider
} }
// ModelDefinition represents a model with its name and type. // ModelDefinition represents a model with its name and type.
type ModelDefinition struct { type ModelDefinition struct {
Name string Name string
Type string Type string
} }
// // DefineModel defines a DeepSeek model in Genkit. // // DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if !d.initted { if !d.initted {
panic("deepseek.Init not called") panic("deepseek.Init not called")
} }
// Define model info, supporting multiturn and system role. // Define model info, supporting multiturn and system role.
mi := ai.ModelInfo{ mi := ai.ModelInfo{
Label: model.Name, Label: model.Name,
Supports: &ai.ModelSupports{ Supports: &ai.ModelSupports{
Multiturn: true, Multiturn: true,
SystemRole: true, SystemRole: true,
Media: false, // DeepSeek API primarily supports text. Media: false, // DeepSeek API primarily supports text.
Tools: false, // Tools not yet supported in this implementation. Tools: false, // Tools not yet supported in this implementation.
}, },
Versions: []string{}, Versions: []string{},
} }
if info != nil { if info != nil {
mi = *info mi = *info
} }
meta := &ai.ModelInfo{ meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name, // Label: "DeepSeek - " + model.Name,
Label: model.Name, Label: model.Name,
Supports: mi.Supports, Supports: mi.Supports,
Versions: []string{}, Versions: []string{},
} }
gen := &generator{model: model, apiKey: d.APIKey} gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate) return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
} }
// Init initializes the DeepSeek plugin. // Init initializes the DeepSeek plugin.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error { func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.initted { if d.initted {
panic("deepseek.Init already called") panic("deepseek.Init already called")
} }
if d == nil || d.APIKey == "" { if d == nil || d.APIKey == "" {
return fmt.Errorf("deepseek: need APIKey") return fmt.Errorf("deepseek: need APIKey")
} }
d.initted = true d.initted = true
return nil return nil
} }
// generator handles model generation. // generator handles model generation.
type generator struct { type generator struct {
model ModelDefinition model ModelDefinition
apiKey string apiKey string
} }
// generate implements the Genkit model generation interface. // generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// stream := cb != nil // stream := cb != nil
if len(input.Messages) == 0 { if len(input.Messages) == 0 {
return nil, fmt.Errorf("prompt or messages required") return nil, fmt.Errorf("prompt or messages required")
} }
// Set up the Deepseek client // Set up the Deepseek client
// Initialize DeepSeek client. // Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey) client := deepseek.NewClient(g.apiKey)
// Create a chat completion request // Create a chat completion request
request := &deepseek.ChatCompletionRequest{ request := &deepseek.ChatCompletionRequest{
Model: g.model.Name, Model: g.model.Name,
} }
for _, msg := range input.Messages { for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role] role, ok := roleMapping[msg.Role]
if !ok { if !ok {
return nil, fmt.Errorf("unsupported role: %s", msg.Role) return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
} }
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
}
// Send the request and handle the response // Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request) response, err := client.CreateChatCompletion(ctx, request)
if err != nil { if err != nil {
log.Fatalf("error: %v", err) log.Fatalf("error: %v", err)
} }
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
for _, chunk := range response.Choices { // Print the response
p := ai.Part{ fmt.Println("Response:", response.Choices[0].Message.Content)
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index), // Create a final response with the merged chunks
} finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
finalResponse.Message.Content = append(finalResponse.Message.Content,&p) for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
} }
return finalResponse, nil // Return the final merged response
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
} }
// concatMessageParts concatenates message parts into a single string. // concatMessageParts concatenates message parts into a single string.
func concatMessageParts(parts []*ai.Part) string { func concatMessageParts(parts []*ai.Part) string {
var sb strings.Builder var sb strings.Builder
for _, part := range parts { for _, part := range parts {
if part.IsText() { if part.IsText() {
sb.WriteString(part.Text) sb.WriteString(part.Text)
} }
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them. // Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
} }
return sb.String() return sb.String()
} }
/* /*
// Choice represents a completion choice generated by the model. // Choice represents a completion choice generated by the model.
...@@ -205,5 +199,3 @@ type Part struct { ...@@ -205,5 +199,3 @@ type Part struct {
} }
*/ */
...@@ -71,129 +71,129 @@ import ( ...@@ -71,129 +71,129 @@ import (
// Client is a client for the knowledge-base HTTP API.
type Client struct {
	BaseURL string // Base URL, e.g. "http://54.92.111.204:5670"
}

// SpaceRequest is the request body for creating a knowledge space.
type SpaceRequest struct {
	ID         int    `json:"id"`
	Name       string `json:"name"`
	VectorType string `json:"vector_type"`
	DomainType string `json:"domain_type"`
	Desc       string `json:"desc"`
	Owner      string `json:"owner"`
	SpaceID    int    `json:"space_id"`
}

// DocumentRequest is the request body for adding a document.
type DocumentRequest struct {
	DocName   string   `json:"doc_name"`
	DocID     int      `json:"doc_id"`
	DocType   string   `json:"doc_type"`
	DocToken  string   `json:"doc_token"`
	Content   string   `json:"content"`
	Source    string   `json:"source"`
	Labels    string   `json:"labels"`
	Questions []string `json:"questions"`
}

// ChunkParameters holds the chunking parameters of a sync request.
type ChunkParameters struct {
	ChunkStrategy string `json:"chunk_strategy"`
	TextSplitter  string `json:"text_splitter"`
	SplitterType  string `json:"splitter_type"`
	ChunkSize     int    `json:"chunk_size"`
	ChunkOverlap  int    `json:"chunk_overlap"`
	Separator     string `json:"separator"`
	EnableMerge   bool   `json:"enable_merge"`
}

// SyncBatchRequest is one entry of the request body for a batch sync.
type SyncBatchRequest struct {
	DocID           int             `json:"doc_id"`
	SpaceID         string          `json:"space_id"`
	ModelName       string          `json:"model_name"`
	ChunkParameters ChunkParameters `json:"chunk_parameters"`
}

// NewClient creates a new client whose BaseURL is "http://<ip>:<port>".
func NewClient(ip string, port int) *Client {
	return &Client{
		BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
	}
}

// postJSON marshals payload to JSON and POSTs it to c.BaseURL+path with JSON
// Accept and Content-Type headers, using the shared http.DefaultClient.
//
// On success the caller owns the response and must close resp.Body. The HTTP
// status code is not inspected here.
func (c *Client) postJSON(path string, payload any) (*http.Response, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequest(http.MethodPost, c.BaseURL+path, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("Content-Type", "application/json")

	// Reuse one client so connections can be pooled across calls instead of
	// allocating a fresh http.Client for every request.
	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	return resp, nil
}

// AddSpace creates a knowledge space by POSTing req to /knowledge/space/add.
// The caller must close the returned response's Body.
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
	return c.postJSON("/knowledge/space/add", req)
}

// AddDocument adds a document to the space identified by spaceID by POSTing
// req to /knowledge/<spaceID>/document/add.
// The caller must close the returned response's Body.
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
	return c.postJSON(fmt.Sprintf("/knowledge/%s/document/add", spaceID), req)
}

// SyncBatchDocument submits a batch sync for the space identified by spaceID
// by POSTing req to /knowledge/<spaceID>/document/sync_batch.
// The caller must close the returned response's Body.
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
	return c.postJSON(fmt.Sprintf("/knowledge/%s/document/sync_batch", spaceID), req)
}
...@@ -36,393 +36,391 @@ const provider = "milvus" ...@@ -36,393 +36,391 @@ const provider = "milvus"
// Names of the fields that make up the Milvus collection schema.
const (
	idField       = "id"       // primary key field
	vectorField   = "vector"   // float-vector field
	textField     = "text"     // varchar field
	metadataField = "metadata" // JSON field
)
// Milvus holds configuration for the plugin. // Milvus holds configuration for the plugin.
type Milvus struct { type Milvus struct {
// Milvus server address (host:port, e.g., "localhost:19530"). // Milvus server address (host:port, e.g., "localhost:19530").
// Defaults to MILVUS_ADDRESS environment variable. // Defaults to MILVUS_ADDRESS environment variable.
Addr string Addr string
// Username for authentication. // Username for authentication.
// Defaults to MILVUS_USERNAME. // Defaults to MILVUS_USERNAME.
Username string Username string
// Password for authentication. // Password for authentication.
// Defaults to MILVUS_PASSWORD. // Defaults to MILVUS_PASSWORD.
Password string Password string
// Token for authentication (alternative to username/password). // Token for authentication (alternative to username/password).
// Defaults to MILVUS_TOKEN. // Defaults to MILVUS_TOKEN.
Token string Token string
client client.Client // Milvus client. client client.Client // Milvus client.
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the plugin name. // Name returns the plugin name.
func (m *Milvus) Name() string { func (m *Milvus) Name() string {
return provider return provider
} }
// Init initializes the Milvus plugin. // Init initializes the Milvus plugin.
func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) { func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if m == nil { if m == nil {
m = &Milvus{} m = &Milvus{}
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
defer func() { defer func() {
if err != nil { if err != nil {
err = fmt.Errorf("milvus.Init: %w", err) err = fmt.Errorf("milvus.Init: %w", err)
} }
}() }()
if m.initted { if m.initted {
return errors.New("plugin already initialized") return errors.New("plugin already initialized")
} }
// Load configuration. // Load configuration.
addr := m.Addr addr := m.Addr
if addr == "" { if addr == "" {
addr = os.Getenv("MILVUS_ADDRESS") addr = os.Getenv("MILVUS_ADDRESS")
} }
if addr == "" { if addr == "" {
return errors.New("milvus address required") return errors.New("milvus address required")
} }
username := m.Username username := m.Username
if username == "" { if username == "" {
username = os.Getenv("MILVUS_USERNAME") username = os.Getenv("MILVUS_USERNAME")
} }
password := m.Password password := m.Password
if password == "" { if password == "" {
password = os.Getenv("MILVUS_PASSWORD") password = os.Getenv("MILVUS_PASSWORD")
} }
token := m.Token token := m.Token
if token == "" { if token == "" {
token = os.Getenv("MILVUS_TOKEN") token = os.Getenv("MILVUS_TOKEN")
} }
// Initialize Milvus client. // Initialize Milvus client.
config := client.Config{ config := client.Config{
Address: addr, Address: addr,
Username: username, Username: username,
Password: password, Password: password,
APIKey: token, APIKey: token,
} }
client, err := client.NewClient(ctx, config) client, err := client.NewClient(ctx, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize Milvus client: %v", err) return fmt.Errorf("failed to initialize Milvus client: %v", err)
} }
m.client = client m.client = client
m.initted = true m.initted = true
return nil return nil
} }
// CollectionConfig holds configuration for an indexer/retriever pair. // CollectionConfig holds configuration for an indexer/retriever pair.
type CollectionConfig struct { type CollectionConfig struct {
// Milvus collection name. Must not be empty. // Milvus collection name. Must not be empty.
Collection string Collection string
// Embedding vector dimension (e.g., 1536 for text-embedding-ada-002). // Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
Dimension int Dimension int
// Embedder for generating vectors. // Embedder for generating vectors.
Embedder ai.Embedder Embedder ai.Embedder
// Embedder options. // Embedder options.
EmbedderOptions any EmbedderOptions any
} }
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection. // DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) { func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
if cfg.Embedder == nil { if cfg.Embedder == nil {
return nil, nil, errors.New("milvus: Embedder required") return nil, nil, errors.New("milvus: Embedder required")
} }
if cfg.Collection == "" { if cfg.Collection == "" {
return nil, nil, errors.New("milvus: collection name required") return nil, nil, errors.New("milvus: collection name required")
} }
if cfg.Dimension <= 0 { if cfg.Dimension <= 0 {
return nil, nil, errors.New("milvus: dimension must be positive") return nil, nil, errors.New("milvus: dimension must be positive")
} }
m := genkit.LookupPlugin(g, provider) m := genkit.LookupPlugin(g, provider)
if m == nil { if m == nil {
return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?") return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
} }
milvus := m.(*Milvus) milvus := m.(*Milvus)
ds, err := milvus.newDocStore(ctx, &cfg) ds, err := milvus.newDocStore(ctx, &cfg)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index) indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve) retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
return indexer, retriever, nil return indexer, retriever, nil
} }
// docStore defines an Indexer and a Retriever. // docStore defines an Indexer and a Retriever.
type docStore struct { type docStore struct {
client client.Client client client.Client
collection string collection string
dimension int dimension int
embedder ai.Embedder embedder ai.Embedder
embedderOptions map[string]interface{} embedderOptions map[string]interface{}
} }
// newDocStore creates a docStore. // newDocStore creates a docStore.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) { func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
if m.client == nil { if m.client == nil {
return nil, errors.New("milvus.Init not called") return nil, errors.New("milvus.Init not called")
} }
// Check/create collection. // Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection) exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err) return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
} }
if !exists { if !exists {
// Define schema. // Define schema.
schema := &entity.Schema{ schema := &entity.Schema{
CollectionName: cfg.Collection, CollectionName: cfg.Collection,
Fields: []*entity.Field{ Fields: []*entity.Field{
{ {
Name: idField, Name: idField,
DataType: entity.FieldTypeInt64, DataType: entity.FieldTypeInt64,
PrimaryKey: true, PrimaryKey: true,
AutoID: true, AutoID: true,
}, },
{ {
Name: vectorField, Name: vectorField,
DataType: entity.FieldTypeFloatVector, DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{ TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension), "dim": fmt.Sprintf("%d", cfg.Dimension),
}, },
}, },
{ {
Name: textField, Name: textField,
DataType: entity.FieldTypeVarChar, DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{ TypeParams: map[string]string{
"max_length": "65535", "max_length": "65535",
}, },
}, },
{ {
Name: metadataField, Name: metadataField,
DataType: entity.FieldTypeJSON, DataType: entity.FieldTypeJSON,
}, },
}, },
} }
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err) return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
} }
// Create HNSW index. // Create HNSW index.
index, err := entity.NewIndexHNSW( index, err := entity.NewIndexHNSW(
entity.L2, entity.L2,
8, // M 8, // M
96, // efConstruction 96, // efConstruction
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err) return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
} }
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false) err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create index: %v", err) return nil, fmt.Errorf("failed to create index: %v", err)
} }
} }
// Load collection. // Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false) err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err) return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
} }
// Convert EmbedderOptions to map[string]interface{}. // Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{} var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil { if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{}) opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions) return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
} }
embedderOptions = opts embedderOptions = opts
} else { } else {
embedderOptions = make(map[string]interface{}) embedderOptions = make(map[string]interface{})
} }
return &docStore{ return &docStore{
client: m.client, client: m.client,
collection: cfg.Collection, collection: cfg.Collection,
dimension: cfg.Dimension, dimension: cfg.Dimension,
embedder: cfg.Embedder, embedder: cfg.Embedder,
embedderOptions: embedderOptions, embedderOptions: embedderOptions,
}, nil }, nil
} }
// Indexer returns the indexer for a collection. // Indexer returns the indexer for a collection.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer { func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
return genkit.LookupIndexer(g, provider, collection) return genkit.LookupIndexer(g, provider, collection)
} }
// Retriever returns the retriever for a collection. // Retriever returns the retriever for a collection.
func Retriever(g *genkit.Genkit, collection string) ai.Retriever { func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection) return genkit.LookupRetriever(g, provider, collection)
} }
// Index implements the Indexer.Index method. // Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 { if len(req.Documents) == 0 {
return nil return nil
} }
// Embed documents. // Embed documents.
ereq := &ai.EmbedRequest{ ereq := &ai.EmbedRequest{
Input: req.Documents, Input: req.Documents,
Options: ds.embedderOptions, Options: ds.embedderOptions,
} }
eres, err := ds.embedder.Embed(ctx, ereq) eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil { if err != nil {
return fmt.Errorf("milvus index embedding failed: %w", err) return fmt.Errorf("milvus index embedding failed: %w", err)
} }
// Validate embedding count matches document count. // Validate embedding count matches document count.
if len(eres.Embeddings) != len(req.Documents) { if len(eres.Embeddings) != len(req.Documents) {
return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents)) return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
} }
// Prepare row-based data. // Prepare row-based data.
var rows []interface{} var rows []interface{}
for i, emb := range eres.Embeddings { for i, emb := range eres.Embeddings {
doc := req.Documents[i] doc := req.Documents[i]
var sb strings.Builder var sb strings.Builder
for _, p := range doc.Content { for _, p := range doc.Content {
if p.IsText() { if p.IsText() {
sb.WriteString(p.Text) sb.WriteString(p.Text)
} }
} }
text := sb.String() text := sb.String()
metadata := doc.Metadata metadata := doc.Metadata
if metadata == nil { if metadata == nil {
metadata = make(map[string]interface{}) metadata = make(map[string]interface{})
} }
// Create row with explicit metadata field. // Create row with explicit metadata field.
row := make(map[string]interface{}) row := make(map[string]interface{})
row["vector"] = emb.Embedding // []float32 row["vector"] = emb.Embedding // []float32
row["text"] = text row["text"] = text
row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
rows = append(rows, row) rows = append(rows, row)
// Debug: Log row contents. // Debug: Log row contents.
fmt.Printf("Row %d: vector_len=%d, text=%q, metadata=%v\n", i, len(emb.Embedding), text, metadata) fmt.Printf("Row %d: vector_len=%d, text=%q, metadata=%v\n", i, len(emb.Embedding), text, metadata)
} }
// Debug: Log total rows. // Debug: Log total rows.
fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection) fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// Insert rows into Milvus. // Insert rows into Milvus.
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows) _, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil { if err != nil {
return fmt.Errorf("milvus insert rows failed: %w", err) return fmt.Errorf("milvus insert rows failed: %w", err)
} }
return nil return nil
} }
// RetrieverOptions holds the per-request options for Milvus retrieval.
type RetrieverOptions struct {
	// Count is the maximum number of documents to retrieve.
	Count int `json:"count,omitempty"`
	// MetricType names the similarity metric (e.g., "L2", "IP").
	MetricType string `json:"metric_type,omitempty"`
}
// Retrieve implements the Retriever.Retrieve method. // Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
count := 3 // Default. count := 3 // Default.
metricTypeStr := "L2" metricTypeStr := "L2"
if req.Options != nil { if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions) ropt, ok := req.Options.(*RetrieverOptions)
if !ok { if !ok {
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{}) return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
} }
if ropt.Count > 0 { if ropt.Count > 0 {
count = ropt.Count count = ropt.Count
} }
if ropt.MetricType != "" { if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType metricTypeStr = ropt.MetricType
} }
} }
// Map string metric type to entity.MetricType. // Map string metric type to entity.MetricType.
var metricType entity.MetricType var metricType entity.MetricType
switch metricTypeStr { switch metricTypeStr {
case "L2": case "L2":
metricType = entity.L2 metricType = entity.L2
case "IP": case "IP":
metricType = entity.IP metricType = entity.IP
default: default:
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr) return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
} }
// Embed query. // Embed query.
ereq := &ai.EmbedRequest{ ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query}, Input: []*ai.Document{req.Query},
Options: ds.embedderOptions, Options: ds.embedderOptions,
} }
eres, err := ds.embedder.Embed(ctx, ereq) eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil { if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err) return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
} }
if len(eres.Embeddings) == 0 { if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query") return nil, errors.New("no embeddings generated for query")
} }
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding) queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters. // Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil { if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err) return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
} }
// Perform search. // Perform search.
results, err := ds.client.Search( results, err := ds.client.Search(
ctx, ctx,
ds.collection, ds.collection,
[]string{}, // partitions []string{}, // partitions
"", // expr "", // expr
[]string{textField, metadataField}, // output fields []string{textField, metadataField}, // output fields
[]entity.Vector{queryVector}, []entity.Vector{queryVector},
vectorField, vectorField,
metricType, metricType,
count, count,
searchParams, searchParams,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err) return nil, fmt.Errorf("milvus search failed: %v", err)
} }
// Process results. // Process results.
var docs []*ai.Document var docs []*ai.Document
for _, result := range results { for _, result := range results {
for i := 0; i < result.ResultCount; i++ { for i := 0; i < result.ResultCount; i++ {
textCol := result.Fields.GetColumn(textField) textCol := result.Fields.GetColumn(textField)
text, err := textCol.GetAsString(i) text, err := textCol.GetAsString(i)
if err != nil { if err != nil {
continue continue
} }
var metadata map[string]interface{} var metadata map[string]interface{}
doc := ai.DocumentFromText(text, metadata) doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc) docs = append(docs, doc)
} }
} }
return &ai.RetrieverResponse{ return &ai.RetrieverResponse{
Documents: docs, Documents: docs,
}, nil }, nil
} }
...@@ -33,156 +33,155 @@ import ( ...@@ -33,156 +33,155 @@ import (
// MockEmbedder is a stateless embedder used by the tests in this file.
type MockEmbedder struct{}

// Name returns the fixed identifier of the mock embedder.
func (m *MockEmbedder) Name() string {
	const name = "mock-embedder"
	return name
}
func (m *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { func (m *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
resp := &ai.EmbedResponse{} resp := &ai.EmbedResponse{}
for range req.Input { for range req.Input {
// Generate a simple embedding (768-dimensional vector of ones) // Generate a simple embedding (768-dimensional vector of ones)
embedding := make([]float32, 768) embedding := make([]float32, 768)
for i := range embedding { for i := range embedding {
embedding[i] = 1.0 embedding[i] = 1.0
} }
resp.Embeddings = append(resp.Embeddings, &ai.Embedding{Embedding: embedding}) resp.Embeddings = append(resp.Embeddings, &ai.Embedding{Embedding: embedding})
} }
return resp, nil return resp, nil
} }
// dropCollection cleans up a test collection. // dropCollection cleans up a test collection.
func dropCollection(ctx context.Context, client client.Client, collectionName string) error { func dropCollection(ctx context.Context, client client.Client, collectionName string) error {
exists, err := client.HasCollection(ctx, collectionName) exists, err := client.HasCollection(ctx, collectionName)
if err != nil { if err != nil {
return fmt.Errorf("check collection: %w", err) return fmt.Errorf("check collection: %w", err)
} }
if exists { if exists {
err = client.DropCollection(ctx, collectionName) err = client.DropCollection(ctx, collectionName)
if err != nil { if err != nil {
return fmt.Errorf("drop collection: %w", err) return fmt.Errorf("drop collection: %w", err)
} }
} }
return nil return nil
} }
func TestMilvusIntegration(t *testing.T) { func TestMilvusIntegration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Initialize Milvus plugin // Initialize Milvus plugin
ms := Milvus{ ms := Milvus{
Addr: "54.92.111.204:19530", // Milvus gRPC endpoint Addr: "54.92.111.204:19530", // Milvus gRPC endpoint
} }
// Initialize Genkit with Milvus plugin // Initialize Genkit with Milvus plugin
g, err := genkit.Init(ctx, genkit.WithPlugins(&ms)) g, err := genkit.Init(ctx, genkit.WithPlugins(&ms))
if err != nil { if err != nil {
t.Fatalf("genkit.Init failed: %v", err) t.Fatalf("genkit.Init failed: %v", err)
} }
// Get the Milvus client for cleanup // Get the Milvus client for cleanup
m, ok := genkit.LookupPlugin(g, provider).(*Milvus) m, ok := genkit.LookupPlugin(g, provider).(*Milvus)
if !ok { if !ok {
t.Fatalf("Failed to lookup Milvus plugin") t.Fatalf("Failed to lookup Milvus plugin")
} }
defer m.client.Close() defer m.client.Close()
// Generate unique collection name // Generate unique collection name
collectionName := fmt.Sprintf("test_collection_%d", time.Now().UnixNano()) collectionName := fmt.Sprintf("test_collection_%d", time.Now().UnixNano())
// Configure collection // Configure collection
cfg := CollectionConfig{ cfg := CollectionConfig{
Collection: collectionName, Collection: collectionName,
Dimension: 768, // Match mock embedder dimension Dimension: 768, // Match mock embedder dimension
Embedder: &MockEmbedder{}, Embedder: &MockEmbedder{},
EmbedderOptions: map[string]interface{}{}, // Explicitly set as map EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
} }
// Define indexer and retriever // Define indexer and retriever
indexer, retriever, err := DefineIndexerAndRetriever(ctx, g, cfg) indexer, retriever, err := DefineIndexerAndRetriever(ctx, g, cfg)
if err != nil { if err != nil {
t.Fatalf("DefineIndexerAndRetriever failed: %v", err) t.Fatalf("DefineIndexerAndRetriever failed: %v", err)
} }
// Clean up collection after test // Clean up collection after test
// defer func() { // defer func() {
// if err := dropCollection(ctx, m.client, collectionName); err != nil { // if err := dropCollection(ctx, m.client, collectionName); err != nil {
// t.Errorf("Cleanup failed: %v", err) // t.Errorf("Cleanup failed: %v", err)
// } // }
// }() // }()
t.Run("Index and Retrieve", func(t *testing.T) { t.Run("Index and Retrieve", func(t *testing.T) {
// Index documents // Index documents
documents := []*ai.Document{ documents := []*ai.Document{
{ {
Content: []*ai.Part{ai.NewTextPart("Hello world")}, Content: []*ai.Part{ai.NewTextPart("Hello world")},
Metadata: map[string]interface{}{"id": int64(1), "category": "greeting"}, Metadata: map[string]interface{}{"id": int64(1), "category": "greeting"},
}, },
{ {
Content: []*ai.Part{ai.NewTextPart("AI is amazing")}, Content: []*ai.Part{ai.NewTextPart("AI is amazing")},
Metadata: map[string]interface{}{"id": int64(2), "category": "tech"}, Metadata: map[string]interface{}{"id": int64(2), "category": "tech"},
}, },
} }
req := &ai.IndexerRequest{Documents: documents} req := &ai.IndexerRequest{Documents: documents}
err := indexer.Index(ctx, req) err := indexer.Index(ctx, req)
if err != nil { if err != nil {
t.Fatalf("Index failed: %v", err) t.Fatalf("Index failed: %v", err)
} }
// Wait briefly to ensure Milvus processes the index // Wait briefly to ensure Milvus processes the index
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Retrieve documents // Retrieve documents
queryReq := &ai.RetrieverRequest{ queryReq := &ai.RetrieverRequest{
Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}}, Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
Options: &RetrieverOptions{ Options: &RetrieverOptions{
Count: 2, Count: 2,
MetricType: "L2", MetricType: "L2",
}, },
} }
resp, err := retriever.Retrieve(ctx, queryReq) resp, err := retriever.Retrieve(ctx, queryReq)
if err != nil { if err != nil {
t.Fatalf("Retrieve failed: %v", err) t.Fatalf("Retrieve failed: %v", err)
} }
// Verify results // Verify results
assert.NotNil(t, resp, "Response should not be nil") assert.NotNil(t, resp, "Response should not be nil")
assert.NotEmpty(t, resp.Documents, "Should return at least one document") assert.NotEmpty(t, resp.Documents, "Should return at least one document")
for _, doc := range resp.Documents { for _, doc := range resp.Documents {
assert.NotEmpty(t, doc.Content[0].Text, "Document text should not be empty") assert.NotEmpty(t, doc.Content[0].Text, "Document text should not be empty")
// Note: Mock embedder returns identical vectors, so results may not be exact // Note: Mock embedder returns identical vectors, so results may not be exact
if strings.Contains(doc.Content[0].Text, "Hello world") || strings.Contains(doc.Content[0].Text, "AI is amazing") { if strings.Contains(doc.Content[0].Text, "Hello world") || strings.Contains(doc.Content[0].Text, "AI is amazing") {
continue continue
} }
t.Errorf("Unexpected document text: %s", doc.Content[0].Text) t.Errorf("Unexpected document text: %s", doc.Content[0].Text)
} }
}) })
t.Run("Empty Index", func(t *testing.T) { t.Run("Empty Index", func(t *testing.T) {
req := &ai.IndexerRequest{Documents: []*ai.Document{}} req := &ai.IndexerRequest{Documents: []*ai.Document{}}
err := indexer.Index(ctx, req) err := indexer.Index(ctx, req)
assert.NoError(t, err, "Indexing empty documents should succeed") assert.NoError(t, err, "Indexing empty documents should succeed")
}) })
t.Run("Invalid Retrieve Options", func(t *testing.T) { t.Run("Invalid Retrieve Options", func(t *testing.T) {
queryReq := &ai.RetrieverRequest{ queryReq := &ai.RetrieverRequest{
Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}}, Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
Options: &RetrieverOptions{MetricType: "INVALID"}, Options: &RetrieverOptions{MetricType: "INVALID"},
} }
_, err := retriever.Retrieve(ctx, queryReq) _, err := retriever.Retrieve(ctx, queryReq)
assert.Error(t, err, "Should fail with invalid metric type") assert.Error(t, err, "Should fail with invalid metric type")
assert.Contains(t, err.Error(), "unsupported metric type") assert.Contains(t, err.Error(), "unsupported metric type")
}) })
t.Run("Invalid Embedder Options", func(t *testing.T) { t.Run("Invalid Embedder Options", func(t *testing.T) {
// Test with invalid EmbedderOptions type // Test with invalid EmbedderOptions type
invalidCfg := CollectionConfig{ invalidCfg := CollectionConfig{
Collection: collectionName + "_invalid", Collection: collectionName + "_invalid",
Dimension: 768, Dimension: 768,
Embedder: &MockEmbedder{}, Embedder: &MockEmbedder{},
EmbedderOptions: "not-a-map", // Invalid type EmbedderOptions: "not-a-map", // Invalid type
} }
_, _, err := DefineIndexerAndRetriever(ctx, g, invalidCfg) _, _, err := DefineIndexerAndRetriever(ctx, g, invalidCfg)
assert.Error(t, err, "Should fail with invalid EmbedderOptions type") assert.Error(t, err, "Should fail with invalid EmbedderOptions type")
assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}") assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}")
}) })
} }
...@@ -12,188 +12,188 @@ import ( ...@@ -12,188 +12,188 @@ import (
) )
// connString holds the value of the -dbconn command-line flag, the database
// connection string. It is empty unless the flag is supplied.
var connString = flag.String("dbconn", "", "database connection string")
// QA represents one record of the qa table. Pointer fields are nil when the
// corresponding column is NULL.
type QA struct {
	// ID is the primary key.
	ID int64
	// CreatedAt is the creation time.
	CreatedAt time.Time
	// UserID is the nullable user ID.
	UserID *int64
	// Username is the nullable user name.
	Username *string
	// Question is the nullable question text.
	Question *string
	// Answer is the nullable answer text.
	Answer *string
}
// QAStore 定义 DAO 接口 // QAStore 定义 DAO 接口
type QAStore interface { type QAStore interface {
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录 // GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error) WriteQA(ctx context.Context, qa QA) (int64, error)
} }
// qaStore is the implementation of the QAStore interface.
type qaStore struct {
	// db is the database handle used for all queries.
	db *sql.DB
}
// NewQAStore 创建新的 QAStore 实例 // NewQAStore 创建新的 QAStore 实例
func NewQAStore(db *sql.DB) QAStore { func NewQAStore(db *sql.DB) QAStore {
return &qaStore{db: db} return &qaStore{db: db}
} }
// GetLatestQA 从 latest_qa 视图读取数据 // GetLatestQA 从 latest_qa 视图读取数据
func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) { func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) {
query := ` query := `
SELECT id, created_at, user_id, username, question, answer SELECT id, created_at, user_id, username, question, answer
FROM latest_qa FROM latest_qa
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)` WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
args := []interface{}{userID} args := []interface{}{userID}
if userID == nil { if userID == nil {
args = []interface{}{nil} args = []interface{}{nil}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, fmt.Errorf("query latest_qa: %w", err) return nil, fmt.Errorf("query latest_qa: %w", err)
} }
defer rows.Close() defer rows.Close()
var results []QA var results []QA
for rows.Next() { for rows.Next() {
var qa QA var qa QA
var userIDVal sql.NullInt64 var userIDVal sql.NullInt64
var username, question, answer sql.NullString var username, question, answer sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil { if err := rows.Scan(&qa.ID, &qa.CreatedAt, &userIDVal, &username, &question, &answer); err != nil {
return nil, fmt.Errorf("scan row: %w", err) return nil, fmt.Errorf("scan row: %w", err)
} }
if userIDVal.Valid { if userIDVal.Valid {
qa.UserID = &userIDVal.Int64 qa.UserID = &userIDVal.Int64
} }
if username.Valid { if username.Valid {
qa.Username = &username.String qa.Username = &username.String
} }
if question.Valid { if question.Valid {
qa.Question = &question.String qa.Question = &question.String
} }
if answer.Valid { if answer.Valid {
qa.Answer = &answer.String qa.Answer = &answer.String
} }
results = append(results, qa) results = append(results, qa)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err) return nil, fmt.Errorf("row iteration: %w", err)
} }
return results, nil return results, nil
} }
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) { func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 { if qa.ID != 0 {
// 更新记录 // 更新记录
query := ` query := `
UPDATE qa UPDATE qa
SET user_id = $1, username = $2, question = $3, answer = $4 SET user_id = $1, username = $2, question = $3, answer = $4
WHERE id = $5 WHERE id = $5
RETURNING id` RETURNING id`
var updatedID int64 var updatedID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID) err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID) return 0, fmt.Errorf("no record found with id %d", qa.ID)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("update qa: %w", err) return 0, fmt.Errorf("update qa: %w", err)
} }
return updatedID, nil return updatedID, nil
} }
// 插入新记录 // 插入新记录
query := ` query := `
INSERT INTO qa (user_id, username, question, answer) INSERT INTO qa (user_id, username, question, answer)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
RETURNING id` RETURNING id`
var newID int64 var newID int64
err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID) err := s.db.QueryRowContext(ctx, query, qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
if err != nil { if err != nil {
return 0, fmt.Errorf("insert qa: %w", err) return 0, fmt.Errorf("insert qa: %w", err)
} }
return newID, nil return newID, nil
} }
func mainQA() { func mainQA() {
flag.Parse() flag.Parse()
ctx := context.Background() ctx := context.Background()
if *connString == "" { if *connString == "" {
log.Fatal("need -dbconn") log.Fatal("need -dbconn")
} }
db, err := sql.Open("postgres", *connString) db, err := sql.Open("postgres", *connString)
if err != nil { if err != nil {
log.Fatalf("open database: %v", err) log.Fatalf("open database: %v", err)
} }
defer db.Close() defer db.Close()
store := NewQAStore(db) store := NewQAStore(db)
// 示例:读取 user_id=101 的最新 QA // 示例:读取 user_id=101 的最新 QA
results, err := store.GetLatestQA(ctx, int64Ptr(101)) results, err := store.GetLatestQA(ctx, int64Ptr(101))
if err != nil { if err != nil {
log.Fatalf("get latest QA: %v", err) log.Fatalf("get latest QA: %v", err)
} }
for _, qa := range results { for _, qa := range results {
fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n", fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer)) qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
} }
// 示例:插入新 QA // 示例:插入新 QA
newQA := QA{ newQA := QA{
UserID: int64Ptr(101), UserID: int64Ptr(101),
Username: stringPtr("alice"), Username: stringPtr("alice"),
Question: stringPtr("What is AI?"), Question: stringPtr("What is AI?"),
Answer: stringPtr("AI is..."), Answer: stringPtr("AI is..."),
} }
newID, err := store.WriteQA(ctx, newQA) newID, err := store.WriteQA(ctx, newQA)
if err != nil { if err != nil {
log.Fatalf("write QA: %v", err) log.Fatalf("write QA: %v", err)
} }
fmt.Printf("Inserted QA with ID: %d\n", newID) fmt.Printf("Inserted QA with ID: %d\n", newID)
// 示例:更新 QA // 示例:更新 QA
updateQA := QA{ updateQA := QA{
ID: newID, ID: newID,
UserID: int64Ptr(101), UserID: int64Ptr(101),
Username: stringPtr("alice_updated"), Username: stringPtr("alice_updated"),
Question: stringPtr("What is NLP?"), Question: stringPtr("What is NLP?"),
Answer: stringPtr("NLP is..."), Answer: stringPtr("NLP is..."),
} }
updatedID, err := store.WriteQA(ctx, updateQA) updatedID, err := store.WriteQA(ctx, updateQA)
if err != nil { if err != nil {
log.Fatalf("update QA: %v", err) log.Fatalf("update QA: %v", err)
} }
fmt.Printf("Updated QA with ID: %d\n", updatedID) fmt.Printf("Updated QA with ID: %d\n", updatedID)
} }
// Helper functions for working with nullable (pointer-typed) values.

// int64Ptr returns a pointer to a newly allocated copy of i.
func int64Ptr(i int64) *int64 {
	p := new(int64)
	*p = i
	return p
}
// stringPtr returns a pointer to a newly allocated copy of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// derefInt64 returns the int64 that p points to, or an untyped nil when p is
// nil. The interface{} result lets callers print either case with %v.
func derefInt64(p *int64) interface{} {
	if p != nil {
		return *p
	}
	return nil
}
// derefString returns the string that p points to, or an untyped nil when p
// is nil. The interface{} result lets callers print either case with %v.
func derefString(p *string) interface{} {
	if p != nil {
		return *p
	}
	return nil
}
\ No newline at end of file
package main
import (
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter combines a rate limiter with a bounded queue of concurrency
// slots.
type RateLimiter struct {
	// limiter is the underlying rate limiter.
	limiter *rate.Limiter
	// queue is a buffered channel; each element in it represents one
	// concurrency slot that is currently in use.
	queue chan struct{}
	// maxWorkers is the capacity the queue was created with.
	maxWorkers int
	// mu is a mutex belonging to this limiter.
	mu sync.Mutex
}
// NewRateLimiter initializes a RateLimiter. ratePerSecond and burst configure
// the underlying rate.Limiter, and maxWorkers is the capacity of the
// concurrency-slot queue.
func NewRateLimiter(ratePerSecond float64, burst, maxWorkers int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiter = rate.NewLimiter(rate.Limit(ratePerSecond), burst)
	rl.queue = make(chan struct{}, maxWorkers)
	rl.maxWorkers = maxWorkers
	return rl
}
// Allow reports whether a request may proceed.
//
// It first blocks in the rate limiter's Wait until the limiter admits the
// request or Wait fails (for example because ctx is done), and then tries to
// claim a concurrency slot without blocking. It returns false when Wait
// returns an error or when every slot is already taken. When it returns true
// the caller holds a slot and must call Release once finished.
func (rl *RateLimiter) Allow(ctx context.Context) bool {
	// No mutex is taken here. rate.Limiter is documented as safe for
	// concurrent use and channel sends are too, so a lock adds nothing; holding
	// one across the blocking Wait made every caller queue on the mutex, where
	// cancelling ctx could not interrupt them.
	if err := rl.limiter.Wait(ctx); err != nil {
		return false
	}

	// Non-blocking send: succeeds only while the queue has spare capacity.
	select {
	case rl.queue <- struct{}{}:
		return true
	default:
		return false
	}
}
// Release frees one concurrency slot by removing one element from the queue.
//
// It never blocks. A caller that holds a slot always finds an element to
// remove, so a Release paired with a successful Allow behaves exactly as a
// plain receive would. A Release with no slot in use is a no-op instead of
// blocking the calling goroutine forever.
func (rl *RateLimiter) Release() {
	select {
	case <-rl.queue:
	default:
		// No slot is in use; nothing to release.
	}
}
// Middleware wraps next in an HTTP handler that consults the limiter before
// each request. When Allow returns false the request is answered with
// 429 Too Many Requests and next is not called; otherwise next serves the
// request and the concurrency slot is released once it returns.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if allowed := rl.Allow(r.Context()); allowed {
			// Deferred so the slot is released even if next panics.
			defer rl.Release()
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(limited)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment