Commit 8604b187 authored by Wade's avatar Wade

plugins use new log lib

parent 16d5d242
{"level":"info","pid":11627,"time":1749036605,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":11627,"time":1749036651,"caller":"/Users/wade/project/wuban/agentchat/main.go:229","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":11627,"time":1749036651,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"qaAsJson--------{\"ID\":16,\"CreatedAt\":\"2025-06-04T09:50:02.837091Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":null,\"Summary\":null,\"To\":\"Bob\",\"ToID\":\"user4567\"}"}
{"level":"info","pid":11627,"time":1749036656,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":11627,"time":1749036665,"caller":"/Users/wade/project/wuban/agentchat/main.go:294","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":12824,"time":1749036924,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":12824,"time":1749036930,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":12824,"time":1749036931,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":27,\"CreatedAt\":\"2025-06-04T11:30:53.508295Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"Well now, if Paris is the heart of France, and the US loves its coffee, then you're probably wondering about the UK. The capital of the UK is London, a truly grand city!\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":12824,"time":1749036933,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":12824,"time":1749036937,"caller":"/Users/wade/project/wuban/agentchat/main.go:320","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":18861,"time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":18861,"method":"DeepSeek.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/deepseek/deepseek.go:91","message":"Initializing DeepSeek plugin"}
{"level":"info","pid":18861,"method":"DeepSeek.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/deepseek/deepseek.go:104","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"Milvus.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:75","message":"Initializing Milvus plugin"}
{"level":"info","pid":18861,"method":"Milvus.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:87","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"GraphKnowledge.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:299","message":"Initializing GraphKnowledge plugin"}
{"level":"info","pid":18861,"method":"NewClient","ip":"54.92.111.204","port":5670,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:93","message":"Creating new GraphRAG client"}
{"level":"info","pid":18861,"method":"GraphKnowledge.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:311","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","collection":"chatRag1","dimension":768,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:152","message":"Defining indexer and retriever"}
{"level":"info","pid":18861,"method":"Milvus.newDocStore","collection":"chatRag1","dimension":768,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:201","message":"Creating new doc store"}
{"level":"info","pid":18861,"method":"Milvus.newDocStore","collection":"chatRag1","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:302","message":"Doc store created successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","collection":"chatRag1","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:182","message":"Indexer and retriever defined successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:357","message":"Defining indexer and retriever"}
{"level":"info","pid":18861,"method":"GraphKnowledge.newDocStore","space_id":"","model_name":"Qwen/Qwen2.5-Coder-32B-Instruct","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:393","message":"Creating new doc store"}
{"level":"info","pid":18861,"method":"GraphKnowledge.newDocStore","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:399","message":"Doc store created successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:376","message":"Indexer and retriever defined successfully"}
{"level":"info","pid":18861,"time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":18861,"time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":28,\"CreatedAt\":\"2025-06-04T11:35:33.142254Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"I'm sorry, but the provided context doesn't contain information about the capital of the UK.\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:450","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","documents":2,"time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:640","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:755","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","documents":1,"time":1749038786,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:892","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038786,"caller":"/Users/wade/project/wuban/agentchat/main.go:320","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":18861,"time":1749038813,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":18861,"time":1749038814,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":29,\"CreatedAt\":\"2025-06-04T12:06:16.535774Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"I'm sorry, but the provided information does not contain the answer to your question about the capital of the UK.\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","time":1749038814,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:450","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","documents":2,"time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:640","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:755","message":"Starting retrieve operation"}
{"level":"fatal","pid":18861,"time":1749038816,"caller":"/Users/wade/project/wuban/agentchat/main.go:362","message":"Server failed: failed to shutdown server: context canceled"}
openapi: 3.0.4 openapi: 3.0.4
info: info:
title: Genkit Chat API title: Genkit Chat API
description: API for interacting with a chat endpoint powered by Genkit. description: API for interacting with chat and indexing endpoints powered by Genkit.
version: 0.1.0 version: 0.1.0
paths: paths:
/index/document: /index/document:
...@@ -44,12 +44,13 @@ paths: ...@@ -44,12 +44,13 @@ paths:
content: content:
application/json: application/json:
schema: schema:
type: object $ref: '#/components/schemas/Response'
properties: examples:
id: success:
type: integer value:
description: The ID of the stored record data: '{"id": 1}'
example: 1 code: 200
msg: "Milvus index data stored successfully"
/index/graph: /index/graph:
post: post:
summary: Store GraphRAG index data summary: Store GraphRAG index data
...@@ -79,7 +80,7 @@ paths: ...@@ -79,7 +80,7 @@ paths:
type: object type: object
description: Additional metadata for the content description: Additional metadata for the content
additionalProperties: true additionalProperties: true
example: example:
source: "user_input" source: "user_input"
timestamp: "2025-06-04T16:54:00+08:00" timestamp: "2025-06-04T16:54:00+08:00"
required: required:
...@@ -90,12 +91,13 @@ paths: ...@@ -90,12 +91,13 @@ paths:
content: content:
application/json: application/json:
schema: schema:
type: object $ref: '#/components/schemas/Response'
properties: examples:
id: success:
type: integer value:
description: The ID of the stored record data: '{"id": 1}'
example: 1 code: 200
msg: "GraphRAG index data stored successfully"
/chat: /chat:
post: post:
summary: Send a chat message summary: Send a chat message
...@@ -145,11 +147,31 @@ paths: ...@@ -145,11 +147,31 @@ paths:
content: content:
application/json: application/json:
schema: schema:
type: object $ref: '#/components/schemas/Response'
properties: examples:
response: success:
type: string value:
description: The response from the chat workflow data: "The capital of the UK is London."
example: "The capital of the UK is London." code: 200
msg: "Chat response generated successfully"
components: components:
schemas: {} schemas:
Response:
type: object
properties:
data:
type: string
description: The response data, typically a JSON string or message
example: '{"id": 1}'
code:
type: integer
description: The response code (200 for success, 400 for invalid input, etc.)
example: 200
msg:
type: string
description: A message describing the result
example: "Milvus index data stored successfully"
required:
- data
- code
- msg
\ No newline at end of file
...@@ -36,8 +36,8 @@ func loggingInit() { ...@@ -36,8 +36,8 @@ func loggingInit() {
// // Configure log rotation with lumberjack // // Configure log rotation with lumberjack
lumberjackLogger := &lumberjack.Logger{ lumberjackLogger := &lumberjack.Logger{
Filename: "/var/log/agent_chat.log", //Filename: "/var/log/agent_chat.log",
//Filename: "./tweet.log", Filename: "agent_chat.log",
MaxSize: 1, // Max size in megabytes before log is rotated MaxSize: 1, // Max size in megabytes before log is rotated
MaxBackups: 3, // Max number of old log files to retain MaxBackups: 3, // Max number of old log files to retain
MaxAge: 28, // Max number of days to retain old log files MaxAge: 28, // Max number of days to retain old log files
......
...@@ -52,8 +52,32 @@ type GraphInput struct { ...@@ -52,8 +52,32 @@ type GraphInput struct {
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
type simpleQaPromptInput struct {
Query string `json:"query"`
Context string `json:"context"`
Graph string `json:"graph"`
Summary string `json:"summary"`
}
// const simpleQaPromptTemplate = `
// You're a helpful agent that answers the user's questions with a tone and style shaped by the specified personality.
// Here is the user's query: {{query}}
// Here is the context you should use: {{context}} from Milvus
// Graph context: {{graph}}
// Previous conversation summary: {{summary}}
// Personality to adopt: {{personality}}
// Please provide a response that aligns with the given personality while leveraging the provided context, graph, and conversation summary.
// `
const simpleQaPromptTemplate = ` const simpleQaPromptTemplate = `
You're a helpful agent that answers the user's questions with a tone and style shaped by the specified personality. You're a helpful agent that answers the user's questions based on the provided context.
Here is the user's query: {{query}} Here is the user's query: {{query}}
...@@ -63,11 +87,13 @@ Graph context: {{graph}} ...@@ -63,11 +87,13 @@ Graph context: {{graph}}
Previous conversation summary: {{summary}} Previous conversation summary: {{summary}}
Personality to adopt: {{personality}} Instructions:
- If the query is related to a character's personality, adopt the tone and style specified in the Personality context, and generate a response using the Milvus and Graph contexts to inform the personality-driven content.
Please provide a response that aligns with the given personality while leveraging the provided context, graph, and conversation summary. - For all other queries, provide a clear and accurate response using the Milvus and Graph contexts, without emphasizing the Personality context.
- Ensure responses leverage the Previous conversation summary when relevant.
` `
func main() { func main() {
debug := flag.Bool("debug", false, "sets log level to debug") debug := flag.Bool("debug", false, "sets log level to debug")
...@@ -220,7 +246,10 @@ func main() { ...@@ -220,7 +246,10 @@ func main() {
inputAsJson, err := json.Marshal(input) inputAsJson, err := json.Marshal(input)
if err != nil { if err != nil {
return "", err return Response{
Code: 500,
Msg: fmt.Sprintf("json.Marshal: %w", err),
}, nil
} }
log.Info().Msgf("input--------%s", string(inputAsJson)) log.Info().Msgf("input--------%s", string(inputAsJson))
...@@ -234,13 +263,19 @@ func main() { ...@@ -234,13 +263,19 @@ func main() {
}) })
if err != nil { if err != nil {
return "", err return Response{
Code: 500,
Msg: fmt.Sprintf("WriteAndGetLatestQA: %w", err),
}, nil
} }
qaAsJson, err := json.Marshal(lastQa) qaAsJson, err := json.Marshal(lastQa)
if err != nil { if err != nil {
return "", err return Response{
Code: 500,
Msg: fmt.Sprintf("json.Marshal(lastQa): %w", err),
}, nil
} }
log.Info().Msgf("qaAsJson--------%s", string(qaAsJson)) log.Info().Msgf("qaAsJson--------%s", string(qaAsJson))
...@@ -299,7 +334,6 @@ func main() { ...@@ -299,7 +334,6 @@ func main() {
return Response{ return Response{
Data: resp.Text(), Data: resp.Text(),
Code: 200, Code: 200,
Msg: fmt.Sprintf("Document indexed successfully, docname %s", resDocName),
}, nil }, nil
}) })
...@@ -329,12 +363,6 @@ func main() { ...@@ -329,12 +363,6 @@ func main() {
} }
} }
type simpleQaPromptInput struct {
Query string `json:"query"`
Context string `json:"context"`
Graph string `json:"graph"`
Summary string `json:"summary"`
}
type Response struct { type Response struct {
Data string `json:"data"` Data string `json:"data"`
......
...@@ -3,12 +3,12 @@ package deepseek ...@@ -3,12 +3,12 @@ package deepseek
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strings" "strings"
"sync" "sync"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/rs/zerolog/log"
deepseek "github.com/cohesion-org/deepseek-go" deepseek "github.com/cohesion-org/deepseek-go"
) )
...@@ -16,166 +16,409 @@ import ( ...@@ -16,166 +16,409 @@ import (
const provider = "deepseek" const provider = "deepseek"
var ( var (
mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner} mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{ roleMapping = map[ai.Role]string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral", ai.RoleUser: deepseek.ChatMessageRoleUser,
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2", ai.RoleModel: deepseek.ChatMessageRoleAssistant,
// "mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus", ai.RoleSystem: deepseek.ChatMessageRoleSystem,
// "phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2", ai.RoleTool: deepseek.ChatMessageRoleTool,
// "nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe", }
// "granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
// "granite3.3", "command-a", "command-r7b-arabic",
// }
roleMapping = map[ai.Role]string{
ai.RoleUser: deepseek.ChatMessageRoleUser,
ai.RoleModel: deepseek.ChatMessageRoleAssistant,
ai.RoleSystem: deepseek.ChatMessageRoleSystem,
ai.RoleTool: deepseek.ChatMessageRoleTool,
}
) )
// DeepSeek holds configuration for the plugin. // DeepSeek holds configuration for the plugin.
type DeepSeek struct { type DeepSeek struct {
APIKey string // DeepSeek API key APIKey string // DeepSeek API key
//ServerAddress string
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the provider name. // Name returns the provider name.
func (d DeepSeek) Name() string { func (d DeepSeek) Name() string {
return provider return provider
} }
// ModelDefinition represents a model with its name and type. // ModelDefinition represents a model with its name and type.
type ModelDefinition struct { type ModelDefinition struct {
Name string Name string
Type string Type string
} }
// // DefineModel defines a DeepSeek model in Genkit. // DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
d.mu.Lock() log.Info().
defer d.mu.Unlock() Str("method", "DeepSeek.DefineModel").
if !d.initted { Str("model_name", model.Name).
panic("deepseek.Init not called") Msg("Defining DeepSeek model")
} d.mu.Lock()
defer d.mu.Unlock()
// Define model info, supporting multiturn and system role. if !d.initted {
mi := ai.ModelInfo{ log.Error().Str("method", "DeepSeek.DefineModel").Msg("DeepSeek not initialized")
Label: model.Name, panic("deepseek.Init not called")
Supports: &ai.ModelSupports{ }
Multiturn: true,
SystemRole: true, // Define model info, supporting multiturn and system role.
Media: false, // DeepSeek API primarily supports text. mi := ai.ModelInfo{
Tools: false, // Tools not yet supported in this implementation. Label: model.Name,
}, Supports: &ai.ModelSupports{
Versions: []string{}, Multiturn: true,
} SystemRole: true,
if info != nil { Media: false, // DeepSeek API primarily supports text.
mi = *info Tools: false, // Tools not yet supported in this implementation.
} },
Versions: []string{},
meta := &ai.ModelInfo{ }
// Label: "DeepSeek - " + model.Name, if info != nil {
Label: model.Name, mi = *info
Supports: mi.Supports, }
Versions: []string{},
} meta := &ai.ModelInfo{
gen := &generator{model: model, apiKey: d.APIKey} Label: model.Name,
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate) Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
modelDef := genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
log.Info().
Str("method", "DeepSeek.DefineModel").
Str("model_name", model.Name).
Msg("Model defined successfully")
return modelDef
} }
// Init initializes the DeepSeek plugin. // Init initializes the DeepSeek plugin.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error { func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
d.mu.Lock() log.Info().Str("method", "DeepSeek.Init").Msg("Initializing DeepSeek plugin")
defer d.mu.Unlock() d.mu.Lock()
if d.initted { defer d.mu.Unlock()
panic("deepseek.Init already called") if d.initted {
} log.Error().Str("method", "DeepSeek.Init").Msg("Plugin already initialized")
return fmt.Errorf("deepseek.Init already called")
if d == nil || d.APIKey == "" { }
return fmt.Errorf("deepseek: need APIKey")
} if d == nil || d.APIKey == "" {
d.initted = true log.Error().Str("method", "DeepSeek.Init").Msg("APIKey is required")
return nil return fmt.Errorf("deepseek: need APIKey")
}
d.initted = true
log.Info().Str("method", "DeepSeek.Init").Msg("Initialization successful")
return nil
} }
// generator handles model generation. // generator handles model generation.
type generator struct { type generator struct {
model ModelDefinition model ModelDefinition
apiKey string apiKey string
} }
// generate implements the Genkit model generation interface. // generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
log.Info().
Str("method", "generator.generate").
Str("model_name", g.model.Name).
Int("messages", len(input.Messages)).
Msg("Starting model generation")
if len(input.Messages) == 0 {
log.Error().Str("method", "generator.generate").Msg("Prompt or messages required")
return nil, fmt.Errorf("prompt or messages required")
}
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
log.Debug().Str("method", "generator.generate").Msg("DeepSeek client initialized")
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
}
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
log.Error().
Str("method", "generator.generate").
Str("role", string(msg.Role)).
Msg("Unsupported role")
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
Role: role,
Content: content,
})
log.Debug().
Str("method", "generator.generate").
Str("role", role).
Str("content", content).
Msg("Added message to request")
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Error().
Err(err).
Str("method", "generator.generate").
Msg("Failed to create chat completion")
return nil, fmt.Errorf("create chat completion: %w", err)
}
// stream := cb != nil log.Debug().
if len(input.Messages) == 0 { Str("method", "generator.generate").
return nil, fmt.Errorf("prompt or messages required") Int("choices", len(response.Choices)).
} Msg("Received chat completion response")
// Set up the Deepseek client
// Initialize DeepSeek client. // Create a final response with the merged chunks
client := deepseek.NewClient(g.apiKey) finalResponse := &ai.ModelResponse{
// Create a chat completion request Request: input,
request := &deepseek.ChatCompletionRequest{ FinishReason: ai.FinishReason("stop"),
Model: g.model.Name, Message: &ai.Message{
} Role: ai.RoleModel,
},
for _, msg := range input.Messages { }
role, ok := roleMapping[msg.Role]
if !ok { for _, chunk := range response.Choices {
return nil, fmt.Errorf("unsupported role: %s", msg.Role) log.Debug().
} Str("method", "generator.generate").
content := concatMessageParts(msg.Content) Int("index", chunk.Index).
request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{ Str("content", chunk.Message.Content).
Role: role, Msg("Processing response chunk")
Content: content, p := ai.Part{
}) Text: chunk.Message.Content,
} Kind: ai.PartKind(chunk.Index),
}
// Send the request and handle the response finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
response, err := client.CreateChatCompletion(ctx, request) }
if err != nil {
log.Fatalf("error: %v", err) log.Info().
} Str("method", "generator.generate").
Str("model_name", g.model.Name).
// Print the response Int("content_parts", len(finalResponse.Message.Content)).
fmt.Println("Response:", response.Choices[0].Message.Content) Msg("Model generation completed successfully")
return finalResponse, nil
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
for _, chunk := range response.Choices {
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
} }
// concatMessageParts concatenates message parts into a single string. // concatMessageParts concatenates message parts into a single string.
func concatMessageParts(parts []*ai.Part) string { func concatMessageParts(parts []*ai.Part) string {
var sb strings.Builder log.Debug().
for _, part := range parts { Str("method", "concatMessageParts").
if part.IsText() { Int("parts", len(parts)).
sb.WriteString(part.Text) Msg("Concatenating message parts")
} var sb strings.Builder
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them. for _, part := range parts {
} if part.IsText() {
return sb.String() sb.WriteString(part.Text)
}
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
}
result := sb.String()
log.Debug().
Str("method", "concatMessageParts").
Str("result", result).
Msg("Concatenation complete")
return result
} }
// package deepseek
// import (
// "context"
// "fmt"
// "log"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// deepseek "github.com/cohesion-org/deepseek-go"
// )
// const provider = "deepseek"
// var (
// mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// // toolSupportedModels = []string{
// // "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// // "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
// // "mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
// // "phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
// // "nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
// // "granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
// // "granite3.3", "command-a", "command-r7b-arabic",
// // }
// roleMapping = map[ai.Role]string{
// ai.RoleUser: deepseek.ChatMessageRoleUser,
// ai.RoleModel: deepseek.ChatMessageRoleAssistant,
// ai.RoleSystem: deepseek.ChatMessageRoleSystem,
// ai.RoleTool: deepseek.ChatMessageRoleTool,
// }
// )
// // DeepSeek holds configuration for the plugin.
// type DeepSeek struct {
// APIKey string // DeepSeek API key
// //ServerAddress string
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the provider name.
// func (d DeepSeek) Name() string {
// return provider
// }
// // ModelDefinition represents a model with its name and type.
// type ModelDefinition struct {
// Name string
// Type string
// }
// // // DefineModel defines a DeepSeek model in Genkit.
// func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// d.mu.Lock()
// defer d.mu.Unlock()
// if !d.initted {
// panic("deepseek.Init not called")
// }
// // Define model info, supporting multiturn and system role.
// mi := ai.ModelInfo{
// Label: model.Name,
// Supports: &ai.ModelSupports{
// Multiturn: true,
// SystemRole: true,
// Media: false, // DeepSeek API primarily supports text.
// Tools: false, // Tools not yet supported in this implementation.
// },
// Versions: []string{},
// }
// if info != nil {
// mi = *info
// }
// meta := &ai.ModelInfo{
// // Label: "DeepSeek - " + model.Name,
// Label: model.Name,
// Supports: mi.Supports,
// Versions: []string{},
// }
// gen := &generator{model: model, apiKey: d.APIKey}
// return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
// }
// // Init initializes the DeepSeek plugin.
// func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
// d.mu.Lock()
// defer d.mu.Unlock()
// if d.initted {
// panic("deepseek.Init already called")
// }
// if d == nil || d.APIKey == "" {
// return fmt.Errorf("deepseek: need APIKey")
// }
// d.initted = true
// return nil
// }
// // generator handles model generation.
// type generator struct {
// model ModelDefinition
// apiKey string
// }
// // generate implements the Genkit model generation interface.
// func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// // stream := cb != nil
// if len(input.Messages) == 0 {
// return nil, fmt.Errorf("prompt or messages required")
// }
// // Set up the Deepseek client
// // Initialize DeepSeek client.
// client := deepseek.NewClient(g.apiKey)
// // Create a chat completion request
// request := &deepseek.ChatCompletionRequest{
// Model: g.model.Name,
// }
// for _, msg := range input.Messages {
// role, ok := roleMapping[msg.Role]
// if !ok {
// return nil, fmt.Errorf("unsupported role: %s", msg.Role)
// }
// content := concatMessageParts(msg.Content)
// request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
// Role: role,
// Content: content,
// })
// }
// // Send the request and handle the response
// response, err := client.CreateChatCompletion(ctx, request)
// if err != nil {
// log.Fatalf("error: %v", err)
// }
// // Print the response
// fmt.Println("Response:", response.Choices[0].Message.Content)
// // Create a final response with the merged chunks
// finalResponse := &ai.ModelResponse{
// Request: input,
// FinishReason: ai.FinishReason("stop"),
// Message: &ai.Message{
// Role: ai.RoleModel,
// },
// }
// for _, chunk := range response.Choices {
// p := ai.Part{
// Text: chunk.Message.Content,
// Kind: ai.PartKind(chunk.Index),
// }
// finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
// }
// return finalResponse, nil // Return the final merged response
// }
// // concatMessageParts concatenates message parts into a single string.
// func concatMessageParts(parts []*ai.Part) string {
// var sb strings.Builder
// for _, part := range parts {
// if part.IsText() {
// sb.WriteString(part.Text)
// }
// // Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
// }
// return sb.String()
// }
/* /*
// Choice represents a completion choice generated by the model. // Choice represents a completion choice generated by the model.
......
...@@ -32,181 +32,242 @@ import ( ...@@ -32,181 +32,242 @@ import (
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/rs/zerolog/log"
"github.com/wade-liwei/agentchat/util" "github.com/wade-liwei/agentchat/util"
) )
// Client 知识库客户端 // Client 知识库客户端
type Client struct { type Client struct {
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670" BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
} }
// SpaceRequest 创建空间的请求结构体 // SpaceRequest 创建空间的请求结构体
type SpaceRequest struct { type SpaceRequest struct {
ID int `json:"id"` ID int `json:"id"`
Name string `json:"name"` Name string `json:"name"`
VectorType string `json:"vector_type"` VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"` DomainType string `json:"domain_type"`
Desc string `json:"desc"` Desc string `json:"desc"`
Owner string `json:"owner"` Owner string `json:"owner"`
SpaceID int `json:"space_id"` SpaceID int `json:"space_id"`
} }
// DocumentRequest 添加文档的请求结构体 // DocumentRequest 添加文档的请求结构体
type DocumentRequest struct { type DocumentRequest struct {
DocName string `json:"doc_name"` DocName string `json:"doc_name"`
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
DocType string `json:"doc_type"` DocType string `json:"doc_type"`
DocToken string `json:"doc_token"` DocToken string `json:"doc_token"`
Content string `json:"content"` Content string `json:"content"`
Source string `json:"source"` Source string `json:"source"`
Labels string `json:"labels"` Labels string `json:"labels"`
Questions []string `json:"questions"` Questions []string `json:"questions"`
Metadata map[string]interface{} `json:"metadata"` Metadata map[string]interface{} `json:"metadata"`
} }
// ChunkParameters 分片参数 // ChunkParameters 分片参数
type ChunkParameters struct { type ChunkParameters struct {
ChunkStrategy string `json:"chunk_strategy"` ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"` TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"` SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"` ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"` ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"` Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"` EnableMerge bool `json:"enable_merge"`
} }
// SyncBatchRequest 同步批量处理的请求结构体 // SyncBatchRequest 同步批量处理的请求结构体
type SyncBatchRequest struct { type SyncBatchRequest struct {
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
SpaceID string `json:"space_id"` SpaceID string `json:"space_id"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"` ChunkParameters ChunkParameters `json:"chunk_parameters"`
} }
// NewClient 创建新的客户端实例 // NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client { func NewClient(ip string, port int) *Client {
return &Client{ log.Info().
BaseURL: fmt.Sprintf("http://%s:%d", ip, port), Str("method", "NewClient").
} Str("ip", ip).
Int("port", port).
Msg("Creating new GraphRAG client")
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
} }
// AddSpace 创建知识空间 // AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) { func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL) log.Info().
body, err := json.Marshal(req) Str("method", "Client.AddSpace").
if err != nil { Str("name", req.Name).
return nil, fmt.Errorf("failed to marshal request: %w", err) Str("owner", req.Owner).
} Msg("Adding knowledge space")
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to marshal request")
} return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
client := &http.Client{} log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to create request")
resp, err := client.Do(httpReq) return nil, fmt.Errorf("failed to create request: %w", err)
if err != nil { }
return nil, fmt.Errorf("failed to send request: %w", err)
} httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
return resp, nil
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.AddSpace").
Int("status_code", resp.StatusCode).
Msg("Space addition request completed")
return resp, nil
} }
// AddDocument 添加文档 // AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) { func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID) log.Info().
body, err := json.Marshal(req) Str("method", "Client.AddDocument").
if err != nil { Str("space_id", spaceID).
return nil, fmt.Errorf("failed to marshal request: %w", err) Str("doc_name", req.DocName).
} Msg("Adding document")
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to marshal request")
} return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
client := &http.Client{} log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to create request")
resp, err := client.Do(httpReq) return nil, fmt.Errorf("failed to create request: %w", err)
if err != nil { }
return nil, fmt.Errorf("failed to send request: %w", err)
} httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
return resp, nil
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.AddDocument").
Str("space_id", spaceID).
Int("status_code", resp.StatusCode).
Msg("Document addition request completed")
return resp, nil
} }
// SyncDocumentsRequest defines the request body for the sync documents endpoint. // SyncDocumentsRequest defines the request body for the sync documents endpoint.
type SyncDocumentsRequest struct { type SyncDocumentsRequest struct {
DocIDs []string `json:"doc_ids"` DocIDs []string `json:"doc_ids"`
} }
// SyncDocuments sends a POST request to sync documents for the given spaceID. // SyncDocuments sends a POST request to sync documents for the given spaceID.
func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, err error) { func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, err error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync", c.BaseURL, spaceID) log.Info().
reqBody := SyncDocumentsRequest{ Str("method", "Client.SyncDocuments").
DocIDs: docIDs, Str("space_id", spaceID).
} Strs("doc_ids", docIDs).
body, err := json.Marshal(reqBody) Msg("Syncing documents")
if err != nil { url := fmt.Sprintf("%s/knowledge/%s/document/sync", c.BaseURL, spaceID)
return false, fmt.Errorf("failed to marshal request: %w", err) reqBody := SyncDocumentsRequest{
} DocIDs: docIDs,
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) body, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to create request: %w", err) log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to marshal request")
} return false, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
client := &http.Client{} log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to create request")
resp, err := client.Do(httpReq) return false, fmt.Errorf("failed to create request: %w", err)
if err != nil { }
return false, fmt.Errorf("failed to send request: %w", err)
} httpReq.Header.Set("Accept", "application/json")
defer resp.Body.Close() httpReq.Header.Set("Content-Type", "application/json")
respBody, err := io.ReadAll(resp.Body) client := &http.Client{}
if err != nil { resp, err := client.Do(httpReq)
return false, fmt.Errorf("failed to read response body: %w", err) if err != nil {
} log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to send request")
return false, fmt.Errorf("failed to send request: %w", err)
if resp.StatusCode != http.StatusOK { }
return false, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody)) defer resp.Body.Close()
}
respBody, err := io.ReadAll(resp.Body)
return success, nil if err != nil {
log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to read response body")
return false, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Error().
Str("method", "Client.SyncDocuments").
Int("status_code", resp.StatusCode).
Str("response_body", string(respBody)).
Msg("Sync request failed")
return false, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
}
log.Info().
Str("method", "Client.SyncDocuments").
Str("space_id", spaceID).
Msg("Documents synced successfully")
return true, nil
} }
// SyncBatchDocument 同步批量处理文档 // SyncBatchDocument 同步批量处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) { func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID) log.Info().
body, err := json.Marshal(req) Str("method", "Client.SyncBatchDocument").
if err != nil { Str("space_id", spaceID).
return nil, fmt.Errorf("failed to marshal request: %w", err) Int("requests", len(req)).
} Msg("Syncing batch documents")
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body)) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to marshal request")
} return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json") httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
client := &http.Client{} log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to create request")
resp, err := client.Do(httpReq) return nil, fmt.Errorf("failed to create request: %w", err)
if err != nil { }
return nil, fmt.Errorf("failed to send request: %w", err)
} httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
return resp, nil
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.SyncBatchDocument").
Str("space_id", spaceID).
Int("status_code", resp.StatusCode).
Msg("Batch document sync request completed")
return resp, nil
} }
// The provider used in the registry. // The provider used in the registry.
...@@ -214,305 +275,1040 @@ const provider = "graphrag" ...@@ -214,305 +275,1040 @@ const provider = "graphrag"
// Field names for schema. // Field names for schema.
const ( const (
idField = "id" idField = "id"
textField = "text" textField = "text"
metadataField = "metadata" metadataField = "metadata"
) )
// GraphKnowledge holds configuration for the plugin. // GraphKnowledge holds configuration for the plugin.
type GraphKnowledge struct { type GraphKnowledge struct {
Addr string // Knowledge server address (host:port, e.g., "54.92.111.204:5670"). Addr string // Knowledge server address (host:port, e.g., "54.92.111.204:5670").
client *Client // Knowledge client. client *Client // Knowledge client.
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the plugin name. // Name returns the plugin name.
func (k *GraphKnowledge) Name() string { func (k *GraphKnowledge) Name() string {
return provider return provider
} }
// Init initializes the GraphKnowledge plugin. // Init initializes the GraphKnowledge plugin.
func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) { func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if k == nil { log.Info().Str("method", "GraphKnowledge.Init").Msg("Initializing GraphKnowledge plugin")
k = &GraphKnowledge{} if k == nil {
} k = &GraphKnowledge{}
}
k.mu.Lock()
defer k.mu.Unlock() k.mu.Lock()
defer func() { defer k.mu.Unlock()
if err != nil { defer func() {
err = fmt.Errorf("graphrag.Init: %w", err) if err != nil {
} log.Error().Err(err).Str("method", "GraphKnowledge.Init").Msg("Initialization failed")
}() err = fmt.Errorf("graphrag.Init: %w", err)
} else {
if k.initted { log.Info().Str("method", "GraphKnowledge.Init").Msg("Initialization successful")
return errors.New("plugin already initialized") }
} }()
// Load configuration. if k.initted {
addr := k.Addr return errors.New("plugin already initialized")
if addr == "" { }
addr = "54.92.111.204:5670" // Default address.
} // Load configuration.
addr := k.Addr
// Initialize Knowledge client. if addr == "" {
host, port := parseAddr(addr) addr = "54.92.111.204:5670" // Default address.
client := NewClient(host, port) }
k.client = client
k.initted = true // Initialize Knowledge client.
return nil host, port := parseAddr(addr)
client := NewClient(host, port)
k.client = client
k.initted = true
return nil
} }
// parseAddr splits host:port into host and port. // parseAddr splits host:port into host and port.
func parseAddr(addr string) (string, int) { func parseAddr(addr string) (string, int) {
parts := strings.Split(addr, ":") parts := strings.Split(addr, ":")
if len(parts) != 2 { if len(parts) != 2 {
return "54.92.111.204", 5670 log.Warn().
} Str("method", "parseAddr").
port, _ := strconv.Atoi(parts[1]) Str("addr", addr).
return parts[0], port Msg("Invalid address format, using default")
return "54.92.111.204", 5670
}
port, err := strconv.Atoi(parts[1])
if err != nil {
log.Error().
Err(err).
Str("method", "parseAddr").
Str("port", parts[1]).
Msg("Failed to parse port, using default")
return "54.92.111.204", 5670
}
return parts[0], port
} }
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space. // DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) { func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
log.Info().Str("method", "DefineIndexerAndRetriever").Msg("Defining indexer and retriever")
spaceID := "" spaceID := ""
modelName := "Qwen/Qwen2.5-Coder-32B-Instruct" modelName := "Qwen/Qwen2.5-Coder-32B-Instruct"
k := genkit.LookupPlugin(g, provider) k := genkit.LookupPlugin(g, provider)
if k == nil { if k == nil {
return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?") log.Error().Str("method", "DefineIndexerAndRetriever").Msg("GraphRAG plugin not found")
} return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
knowledge := k.(*GraphKnowledge) }
knowledge := k.(*GraphKnowledge)
ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
if err != nil { ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
return nil, nil, err if err != nil {
} log.Error().Err(err).Str("method", "DefineIndexerAndRetriever").Msg("Failed to create doc store")
return nil, nil, err
indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index) }
retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
return indexer, retriever, nil indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
log.Info().Str("method", "DefineIndexerAndRetriever").Msg("Indexer and retriever defined successfully")
return indexer, retriever, nil
} }
// docStore defines an Indexer and a Retriever. // docStore defines an Indexer and a Retriever.
type docStore struct { type docStore struct {
client *Client client *Client
spaceID string spaceID string
modelName string modelName string
} }
// newDocStore creates a docStore. // newDocStore creates a docStore.
func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) { func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
if k.client == nil { log.Info().
return nil, errors.New("graphrag.Init not called") Str("method", "GraphKnowledge.newDocStore").
} Str("space_id", spaceID).
Str("model_name", modelName).
return &docStore{ Msg("Creating new doc store")
client: k.client, if k.client == nil {
spaceID: spaceID, log.Error().Str("method", "GraphKnowledge.newDocStore").Msg("GraphRAG client not initialized")
modelName: modelName, return nil, errors.New("graphrag.Init not called")
}, nil }
log.Info().Str("method", "GraphKnowledge.newDocStore").Msg("Doc store created successfully")
return &docStore{
client: k.client,
spaceID: spaceID,
modelName: modelName,
}, nil
} }
// Indexer returns the indexer for a space. // Indexer returns the indexer for a space.
func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer { func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
return genkit.LookupIndexer(g, provider, spaceID) log.Info().
Str("method", "Indexer").
Str("space_id", spaceID).
Msg("Looking up indexer")
indexer := genkit.LookupIndexer(g, provider, spaceID)
if indexer == nil {
log.Warn().
Str("method", "Indexer").
Str("space_id", spaceID).
Msg("Indexer not found")
}
return indexer
} }
// Retriever returns the retriever for a space. // Retriever returns the retriever for a space.
func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever { func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
return genkit.LookupRetriever(g, provider, spaceID) log.Info().
Str("method", "Retriever").
Str("space_id", spaceID).
Msg("Looking up retriever")
retriever := genkit.LookupRetriever(g, provider, spaceID)
if retriever == nil {
log.Warn().
Str("method", "Retriever").
Str("space_id", spaceID).
Msg("Retriever not found")
}
return retriever
} }
// generateRandomDocName generates a random alphanumeric string of the specified length. // generateRandomDocName generates a random alphanumeric string of the specified length.
func GenerateRandomDocName(length int) (string, error) { func GenerateRandomDocName(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789" log.Debug().
var result strings.Builder Str("method", "GenerateRandomDocName").
result.Grow(length) Int("length", length).
Msg("Generating random document name")
for i := 0; i < length; i++ { const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) var result strings.Builder
if err != nil { result.Grow(length)
return "", fmt.Errorf("failed to generate random index: %w", err)
} for i := 0; i < length; i++ {
result.WriteByte(charset[idx.Int64()]) idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
} if err != nil {
log.Error().
return result.String(), nil Err(err).
Str("method", "GenerateRandomDocName").
Msg("Failed to generate random index")
return "", fmt.Errorf("failed to generate random index: %w", err)
}
result.WriteByte(charset[idx.Int64()])
}
docName := result.String()
log.Debug().
Str("method", "GenerateRandomDocName").
Str("doc_name", docName).
Msg("Generated document name")
return docName, nil
} }
// ParseJSONResponse parses a JSON byte slice and extracts the success boolean and data fields as a string. // ParseJSONResponse parses a JSON byte slice and extracts the success boolean and data fields as a string.
func ParseJSONResponse(jsonBytes []byte) (success bool, data string, err error) { func ParseJSONResponse(jsonBytes []byte) (success bool, data string, err error) {
// Define struct to capture only the needed fields log.Debug().
type jsonResponse struct { Str("method", "ParseJSONResponse").
Success bool `json:"success"` Str("json", string(jsonBytes)).
Data int `json:"data"` // Use string to capture JSON string data Msg("Parsing JSON response")
} // Define struct to capture only the needed fields
type jsonResponse struct {
var resp jsonResponse Success bool `json:"success"`
if err := json.Unmarshal(jsonBytes, &resp); err != nil { Data int `json:"data"`
return false, "", fmt.Errorf("failed to unmarshal JSON: %w", err) }
}
var resp jsonResponse
return resp.Success, fmt.Sprintf("%d", resp.Data), nil if err := json.Unmarshal(jsonBytes, &resp); err != nil {
log.Error().
Err(err).
Str("method", "ParseJSONResponse").
Msg("Failed to unmarshal JSON")
return false, "", fmt.Errorf("failed to unmarshal JSON: %w", err)
}
dataStr := fmt.Sprintf("%d", resp.Data)
log.Debug().
Str("method", "ParseJSONResponse").
Bool("success", resp.Success).
Str("data", dataStr).
Msg("Parsed JSON response")
return resp.Success, dataStr, nil
} }
type IndexReqOption struct { type IndexReqOption struct {
UserId string UserId string
UserName string UserName string
} }
const DocNameKey = "doc_name" const DocNameKey = "doc_name"
// Index implements the Indexer.Index method. // Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 { log.Info().
return nil Str("method", "docStore.Index").
} Str("space_id", ds.spaceID).
Int("documents", len(req.Documents)).
// Type-assert req.Options to IndexReqOption Msg("Starting index operation")
opt, ok := req.Options.(*IndexReqOption)
if !ok { if len(req.Documents) == 0 {
return fmt.Errorf("invalid options type: got %T, want *IndexReqOption", req.Options) log.Debug().
} Str("method", "docStore.Index").
Str("space_id", ds.spaceID).
// Validate required fields Msg("No documents to index")
if opt.UserId == "" { return nil
return fmt.Errorf("UserId is required in IndexReqOption") }
}
if opt.UserName == "" { // Type-assert req.Options to IndexReqOption
return fmt.Errorf("UserName is required in IndexReqOption") opt, ok := req.Options.(*IndexReqOption)
} if !ok {
log.Error().
// Create knowledge space Str("method", "docStore.Index").
spaceReq := SpaceRequest{ Str("options_type", fmt.Sprintf("%T", req.Options)).
Name: opt.UserId, Msg("Invalid options type")
VectorType: "KnowledgeGraph", return fmt.Errorf("invalid options type: got %T, want *IndexReqOption", req.Options)
DomainType: "Normal", }
Desc: opt.UserName,
Owner: opt.UserId, // Validate required fields
} if opt.UserId == "" {
resp, err := ds.client.AddSpace(spaceReq) log.Error().Str("method", "docStore.Index").Msg("UserId is required")
if err != nil { return fmt.Errorf("UserId is required in IndexReqOption")
return fmt.Errorf("add space: %w", err) }
} if opt.UserName == "" {
defer resp.Body.Close() log.Error().Str("method", "docStore.Index").Msg("UserName is required")
if resp.StatusCode != http.StatusOK { return fmt.Errorf("UserName is required in IndexReqOption")
body, _ := io.ReadAll(resp.Body) }
return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
} // Create knowledge space
spaceReq := SpaceRequest{
fmt.Println("space ok") Name: opt.UserId,
VectorType: "KnowledgeGraph",
spaceId := opt.UserId DomainType: "Normal",
Desc: opt.UserName,
// Index each document Owner: opt.UserId,
for i, doc := range req.Documents { }
// Use DocName from options, fall back to random name if empty resp, err := ds.client.AddSpace(spaceReq)
docName := "" if err != nil {
if v, ok := doc.Metadata[DocNameKey]; ok { log.Error().Err(err).Str("method", "docStore.Index").Msg("Failed to add space")
if str, isString := v.(string); isString { return fmt.Errorf("add space: %w", err)
docName = str }
} else { defer resp.Body.Close()
return fmt.Errorf("must provide doc_name str value in metadata") if resp.StatusCode != http.StatusOK {
} body, _ := io.ReadAll(resp.Body)
} else { log.Error().
return fmt.Errorf("must provide doc_name key in metadata") Str("method", "docStore.Index").
} Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
fmt.Println("docName: ", docName) Msg("Add space failed")
return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
// Add document }
var sb strings.Builder
for _, p := range doc.Content { log.Info().Str("method", "docStore.Index").Str("space_id", opt.UserId).Msg("Space created successfully")
if p.IsText() { spaceId := opt.UserId
sb.WriteString(p.Text)
} // Index each document
} for i, doc := range req.Documents {
text := sb.String() // Use DocName from metadata
fmt.Println("text: ", text) docName, ok := doc.Metadata[DocNameKey].(string)
docReq := DocumentRequest{ if !ok {
DocName: docName, log.Error().
Source: "api", Str("method", "docStore.Index").
DocType: "TEXT", Int("index", i).
Content: text, Msg("Missing doc_name in metadata")
Labels: "", return fmt.Errorf("must provide doc_name key in metadata")
Metadata: doc.Metadata, }
} if docName == "" {
resp, err := ds.client.AddDocument(spaceId, docReq) log.Error().
if err != nil { Str("method", "docStore.Index").
return fmt.Errorf("add document %d: %w", i+1, err) Int("index", i).
} Msg("doc_name is empty")
body, err := io.ReadAll(resp.Body) return fmt.Errorf("must provide non-empty doc_name str value in metadata")
if err != nil { }
resp.Body.Close()
return fmt.Errorf("read add document response %d: %w", i+1, err) log.Debug().
} Str("method", "docStore.Index").
defer resp.Body.Close() Int("index", i).
Str("doc_name", docName).
if resp.StatusCode != http.StatusOK { Msg("Processing document")
return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
} // Add document
var sb strings.Builder
// Parse AddDocument response for _, p := range doc.Content {
ok, idx, err := ParseJSONResponse(body) if p.IsText() {
if err != nil { sb.WriteString(p.Text)
return fmt.Errorf("parse add document response %d: %w", i+1, err) }
} }
if !ok { text := sb.String()
return fmt.Errorf("add document %d failed: response success=false, data=%s", i+1, idx) log.Debug().
} Str("method", "docStore.Index").
fmt.Println("document ok", string(body), idx) Int("index", i).
Str("text", text).
// Sync document Msg("Extracted document text")
_, err = ds.client.SyncDocuments(spaceId, []string{idx})
if err != nil { docReq := DocumentRequest{
return fmt.Errorf("sync document %d: %w", i+1, err) DocName: docName,
} Source: "api",
} DocType: "TEXT",
Content: text,
return nil Labels: "",
Metadata: doc.Metadata,
}
resp, err := ds.client.AddDocument(spaceId, docReq)
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to add document")
return fmt.Errorf("add document %d: %w", i+1, err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to read add document response")
return fmt.Errorf("read add document response %d: %w", i+1, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Error().
Str("method", "docStore.Index").
Int("index", i+1).
Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
Msg("Add document failed")
return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
}
// Parse AddDocument response
ok, idx, err := ParseJSONResponse(body)
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to parse add document")
return fmt.Errorf("parse add document response %d: %w", i+1, err)
}
if !ok {
log.Error().
Str("method", "docStore.Index").
Int("index", i+1).
Str("data", idx).
Msg("Add document response indicated failure")
return fmt.Errorf("add document %d failed: response success=false, data=%s", i+1, idx)
}
log.Info().
Str("method", "docStore.Index").
Int("index", i+1).
Str("doc_id", idx).
Msg("Document added successfully")
// Sync document
_, err = ds.client.SyncDocuments(spaceId, []string{idx})
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to sync document")
return fmt.Errorf("sync document %d: %w", i+1, err)
}
}
log.Info().
Str("method", "docStore.Index").
Str("space_id", ds.spaceID).
Int("documents", len(req.Documents)).
Msg("Index operation completed successfully")
return nil
}
// ChatRequest defines the request structure for chat completions.
type ChatRequest struct {
Model string `json:"model"`
Messages string `json:"messages"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
N int `json:"n"`
MaxTokens int64 `json:"max_tokens"`
Stream bool `json:"stream"`
RepetitionPenalty float64 `json:"repetition_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
ChatMode string `json:"chat_mode"`
ChatParam string `json:"chat_param"`
EnableVis bool `json:"enable_vis"`
}
// ChatResponse defines the response structure from the API.
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent interface{} `json:"reasoning_content"`
} `json:"message"`
FinishReason interface{} `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
// Assuming ai.Part has a Text() method or Text field to get string content.
func partsToString(parts []*ai.Part) string {
log.Debug().
Str("method", "partsToString").
Int("parts", len(parts)).
Msg("Converting parts to string")
var texts []string
for _, part := range parts {
texts = append(texts, part.Text)
}
result := strings.Join(texts, " ")
log.Debug().
Str("method", "partsToString").
Str("result", result).
Msg("Conversion complete")
return result
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
log.Info().
Str("method", "docStore.Retrieve").
Str("space_id", ds.spaceID).
Msg("Starting retrieve operation")
// Format query for retrieval.
queryContent := partsToString(req.Query.Content)
queryText := fmt.Sprintf("Search for: %s", queryContent)
log.Debug().
Str("method", "docStore.Retrieve").
Str("query", queryText).
Msg("Formatted query")
if req.Query.Metadata == nil {
log.Error().
Str("method", "docStore.Retrieve").
Str("metadata_type", fmt.Sprintf("%T", req.Query.Metadata)).
Msg("Query metadata is nil")
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
for k, v := range req.Query.Metadata {
log.Debug().
Str("method", "docStore.Retrieve").
Str("key", k).
Interface("value", v).
Msg("Metadata entry")
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Msg("Missing username in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Msg("Missing user_id in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
// Prepare request for chat completions endpoint.
url := fmt.Sprintf("%s/api/v2/chat/completions", ds.client.BaseURL)
chatReq := ChatRequest{
Model: ds.modelName,
Messages: queryText,
Temperature: 0.7,
TopP: 1,
TopK: -1,
N: 1,
MaxTokens: 0,
Stream: false,
RepetitionPenalty: 1,
FrequencyPenalty: 0,
PresencePenalty: 0,
ChatMode: "chat_knowledge",
ChatParam: userId,
EnableVis: true,
}
log.Debug().
Str("method", "docStore.Retrieve").
Str("url", url).
Interface("chat_request", chatReq).
Msg("Preparing chat completion request")
body, err := json.Marshal(chatReq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to marshal chat request")
return nil, fmt.Errorf("marshal chat request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to create chat request")
return nil, fmt.Errorf("create chat request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to send chat request")
return nil, fmt.Errorf("send chat request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Error().
Str("method", "docStore.Retrieve").
Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
Msg("Chat completion failed")
return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var chatResp ChatResponse
respBody, err := io.ReadAll(resp.Body)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to read response body")
return nil, fmt.Errorf("read chat response body: %w", err)
}
if err := json.Unmarshal(respBody, &chatResp); err != nil {
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("raw_response", string(respBody)).
Msg("Failed to decode chat response")
return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(respBody))
}
log.Debug().
Str("method", "docStore.Retrieve").
Interface("chat_response", chatResp).
Msg("Parsed chat response")
// Convert response to ai.Document
var docs []*ai.Document
if len(chatResp.Choices) > 0 {
content := chatResp.Choices[0].Message.Content
metadata := map[string]interface{}{
util.UserIdKey: userId,
util.UserNameKey: userName,
}
aiDoc := ai.DocumentFromText(content, metadata)
docs = append(docs, aiDoc)
log.Debug().
Str("method", "docStore.Retrieve").
Str("content", content).
Interface("metadata", metadata).
Msg("Created document from response")
}
log.Info().
Str("method", "docStore.Retrieve").
Str("space_id", ds.spaceID).
Int("documents", len(docs)).
Msg("Retrieve operation completed successfully")
return &ai.RetrieverResponse{
Documents: docs,
}, nil
} }
// type IndexReqOption struct{
// UserId string
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package graphrag
// import (
// "bytes"
// "context"
// "crypto/rand"
// "encoding/json"
// "errors"
// "fmt"
// "io"
// "math/big"
// "net/http"
// "strconv"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/wade-liwei/agentchat/util"
// )
// // Client 知识库客户端
// type Client struct {
// BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
// }
// // SpaceRequest 创建空间的请求结构体
// type SpaceRequest struct {
// ID int `json:"id"`
// Name string `json:"name"`
// VectorType string `json:"vector_type"`
// DomainType string `json:"domain_type"`
// Desc string `json:"desc"`
// Owner string `json:"owner"`
// SpaceID int `json:"space_id"`
// }
// // DocumentRequest 添加文档的请求结构体
// type DocumentRequest struct {
// DocName string `json:"doc_name"`
// DocID int `json:"doc_id"`
// DocType string `json:"doc_type"`
// DocToken string `json:"doc_token"`
// Content string `json:"content"`
// Source string `json:"source"`
// Labels string `json:"labels"`
// Questions []string `json:"questions"`
// Metadata map[string]interface{} `json:"metadata"`
// }
// // ChunkParameters 分片参数
// type ChunkParameters struct {
// ChunkStrategy string `json:"chunk_strategy"`
// TextSplitter string `json:"text_splitter"`
// SplitterType string `json:"splitter_type"`
// ChunkSize int `json:"chunk_size"`
// ChunkOverlap int `json:"chunk_overlap"`
// Separator string `json:"separator"`
// EnableMerge bool `json:"enable_merge"`
// }
// // SyncBatchRequest 同步批量处理的请求结构体
// type SyncBatchRequest struct {
// DocID int `json:"doc_id"`
// SpaceID string `json:"space_id"`
// ModelName string `json:"model_name"`
// ChunkParameters ChunkParameters `json:"chunk_parameters"`
// }
// // NewClient 创建新的客户端实例
// func NewClient(ip string, port int) *Client {
// return &Client{
// BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
// }
// }
// // AddSpace 创建知识空间
// func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // AddDocument 添加文档
// func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // SyncDocumentsRequest defines the request body for the sync documents endpoint.
// type SyncDocumentsRequest struct {
// DocIDs []string `json:"doc_ids"`
// }
// // SyncDocuments sends a POST request to sync documents for the given spaceID.
// func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, err error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/sync", c.BaseURL, spaceID)
// reqBody := SyncDocumentsRequest{
// DocIDs: docIDs,
// }
// body, err := json.Marshal(reqBody)
// if err != nil {
// return false, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return false, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return false, fmt.Errorf("failed to send request: %w", err)
// }
// defer resp.Body.Close()
// respBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return false, fmt.Errorf("failed to read response body: %w", err)
// }
// if resp.StatusCode != http.StatusOK {
// return false, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
// }
// return success, nil
// }
// // SyncBatchDocument 同步批量处理文档
// func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // The provider used in the registry.
// const provider = "graphrag"
// // Field names for schema.
// const (
// idField = "id"
// textField = "text"
// metadataField = "metadata"
// )
// // GraphKnowledge holds configuration for the plugin.
// type GraphKnowledge struct {
// Addr string // Knowledge server address (host:port, e.g., "54.92.111.204:5670").
// client *Client // Knowledge client.
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the plugin name.
// func (k *GraphKnowledge) Name() string {
// return provider
// }
// // Init initializes the GraphKnowledge plugin.
// func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
// if k == nil {
// k = &GraphKnowledge{}
// }
// k.mu.Lock()
// defer k.mu.Unlock()
// defer func() {
// if err != nil {
// err = fmt.Errorf("graphrag.Init: %w", err)
// }
// }()
// if k.initted {
// return errors.New("plugin already initialized")
// }
// // Load configuration.
// addr := k.Addr
// if addr == "" {
// addr = "54.92.111.204:5670" // Default address.
// }
// // Initialize Knowledge client.
// host, port := parseAddr(addr)
// client := NewClient(host, port)
// k.client = client
// k.initted = true
// return nil
// }
// // parseAddr splits host:port into host and port.
// func parseAddr(addr string) (string, int) {
// parts := strings.Split(addr, ":")
// if len(parts) != 2 {
// return "54.92.111.204", 5670
// }
// port, _ := strconv.Atoi(parts[1])
// return parts[0], port
// }
// // DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space.
// func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
// spaceID := ""
// modelName := "Qwen/Qwen2.5-Coder-32B-Instruct"
// k := genkit.LookupPlugin(g, provider)
// if k == nil {
// return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
// }
// knowledge := k.(*GraphKnowledge)
// ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
// if err != nil {
// return nil, nil, err
// }
// indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
// retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
// return indexer, retriever, nil
// }
// // docStore defines an Indexer and a Retriever.
// type docStore struct {
// client *Client
// spaceID string
// modelName string
// }
// // newDocStore creates a docStore.
// func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
// if k.client == nil {
// return nil, errors.New("graphrag.Init not called")
// }
// return &docStore{
// client: k.client,
// spaceID: spaceID,
// modelName: modelName,
// }, nil
// }
// // Indexer returns the indexer for a space.
// func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
// return genkit.LookupIndexer(g, provider, spaceID)
// }
// // Retriever returns the retriever for a space.
// func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
// return genkit.LookupRetriever(g, provider, spaceID)
// }
// // generateRandomDocName generates a random alphanumeric string of the specified length.
// func GenerateRandomDocName(length int) (string, error) {
// const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
// var result strings.Builder
// result.Grow(length)
// for i := 0; i < length; i++ {
// idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
// if err != nil {
// return "", fmt.Errorf("failed to generate random index: %w", err)
// }
// result.WriteByte(charset[idx.Int64()])
// }
// return result.String(), nil
// }
// // ParseJSONResponse parses a JSON byte slice and extracts the success boolean and data fields as a string.
// func ParseJSONResponse(jsonBytes []byte) (success bool, data string, err error) {
// // Define struct to capture only the needed fields
// type jsonResponse struct {
// Success bool `json:"success"`
// Data int `json:"data"` // Use string to capture JSON string data
// }
// var resp jsonResponse
// if err := json.Unmarshal(jsonBytes, &resp); err != nil {
// return false, "", fmt.Errorf("failed to unmarshal JSON: %w", err)
// }
// return resp.Success, fmt.Sprintf("%d", resp.Data), nil
// }
// type IndexReqOption struct {
// UserId string
// UserName string // UserName string
// DocName string
// } // }
// const DocNameKey = "doc_name"
// // Index implements the Indexer.Index method. // // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { // func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 { // if len(req.Documents) == 0 {
// return nil // return nil
// } // }
// req.Options // // Type-assert req.Options to IndexReqOption
// opt, ok := req.Options.(*IndexReqOption)
// if !ok {
// return fmt.Errorf("invalid options type: got %T, want *IndexReqOption", req.Options)
// }
// userid := "" // // Validate required fields
// usernmae := "" // if opt.UserId == "" {
// for _, doc := range req.Documents { // return fmt.Errorf("UserId is required in IndexReqOption")
// if v, ok := doc.Metadata["user_id"]; ok { // }
// if str, isString := v.(string); isString { // if opt.UserName == "" {
// userid = str // return fmt.Errorf("UserName is required in IndexReqOption")
// }
// }
// if v, ok := doc.Metadata["username"]; ok {
// if str, isString := v.(string); isString {
// usernmae = str
// }
// }
// } // }
// // Create knowledge space. // // Create knowledge space
// spaceReq := SpaceRequest{ // spaceReq := SpaceRequest{
// Name: userid, // Name: opt.UserId,
// VectorType: "KnowledgeGraph", // VectorType: "KnowledgeGraph",
// DomainType: "Normal", // DomainType: "Normal",
// Desc: usernmae, // Desc: opt.UserName,
// Owner: userid, // Owner: opt.UserId,
// } // }
// resp, err := ds.client.AddSpace(spaceReq) // resp, err := ds.client.AddSpace(spaceReq)
// if err != nil { // if err != nil {
...@@ -526,37 +1322,25 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -526,37 +1322,25 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// fmt.Println("space ok") // fmt.Println("space ok")
// spaceId := userid // spaceId := opt.UserId
// // Index each document. // // Index each document
// for i, doc := range req.Documents { // for i, doc := range req.Documents {
// // Use DocName from options, fall back to random name if empty
// docName := "" // docName := ""
// if v, ok := doc.Metadata["doc_name"]; ok { // if v, ok := doc.Metadata[DocNameKey]; ok {
// if str, isString := v.(string); isString { // if str, isString := v.(string); isString {
// docName = str // docName = str
// } else { // } else {
// return fmt.Errorf("must provide doc_name str value in metadata")
// // Generate random docName.
// var err error
// docName, err = generateRandomDocName(8)
// if err != nil {
// return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// }
// } // }
// } else { // } else {
// // Generate random docName. // return fmt.Errorf("must provide doc_name key in metadata")
// var err error
// docName, err = generateRandomDocName(8)
// if err != nil {
// return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// }
// } // }
// fmt.Println("docName: ", docName) // fmt.Println("docName: ", docName)
// // Add document. // // Add document
// var sb strings.Builder // var sb strings.Builder
// for _, p := range doc.Content { // for _, p := range doc.Content {
// if p.IsText() { // if p.IsText() {
...@@ -564,128 +1348,419 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -564,128 +1348,419 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// } // }
// } // }
// text := sb.String() // text := sb.String()
// fmt.Println("text: ",text) // fmt.Println("text: ", text)
// docReq := DocumentRequest{ // docReq := DocumentRequest{
// DocName: docName, // DocName: docName,
// Source: "api", // Source: "api",
// DocType: "TEXT", // DocType: "TEXT",
// Content: text, // Content: text,
// Labels: "", // Labels: "",
// // Questions: []string{},
// Metadata: doc.Metadata, // Metadata: doc.Metadata,
// } // }
// resp, err := ds.client.AddDocument(spaceId, docReq) // resp, err := ds.client.AddDocument(spaceId, docReq)
// if err != nil { // if err != nil {
// return fmt.Errorf("add document %d: %w", i+1, err) // return fmt.Errorf("add document %d: %w", i+1, err)
// } // }
// body, _ := io.ReadAll(resp.Body) // body, err := io.ReadAll(resp.Body)
// if err != nil {
// resp.Body.Close()
// return fmt.Errorf("read add document response %d: %w", i+1, err)
// }
// defer resp.Body.Close() // defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// if resp.StatusCode != http.StatusOK {
// return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body)) // return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
// } // }
// // Parse AddDocument response
// ok, idx, err := ParseJSONResponse(body) // ok, idx, err := ParseJSONResponse(body)
// if err != nil { // if err != nil {
// return fmt.Errorf("ParseJSONResponse %d: %w", i+1, err) // return fmt.Errorf("parse add document response %d: %w", i+1, err)
// } // }
// if !ok {
// if !ok{ // return fmt.Errorf("add document %d failed: response success=false, data=%s", i+1, idx)
// return fmt.Errorf("ParseJSONResponse body %d: %w", i+1, err)
// } // }
// fmt.Println("document ok",string(body),idx) // fmt.Println("document ok", string(body), idx)
// ok ,err =ds.client.SyncDocuments(spaceId,[]string{idx}) // // Sync document
// _, err = ds.client.SyncDocuments(spaceId, []string{idx})
// if err != nil{ // if err != nil {
// return err // return fmt.Errorf("sync document %d: %w", i+1, err)
// } // }
// } // }
// return nil // return nil
// } // }
// // RetrieverOptions for Knowledge retrieval. // // type IndexReqOption struct{
// type RetrieverOptions struct { // // UserId string
// Count int `json:"count,omitempty"` // Max documents to retrieve. // // UserName string
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP"). // // DocName string
// // }
// // // Index implements the Indexer.Index method.
// // func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// // if len(req.Documents) == 0 {
// // return nil
// // }
// // req.Options
// // userid := ""
// // usernmae := ""
// // for _, doc := range req.Documents {
// // if v, ok := doc.Metadata["user_id"]; ok {
// // if str, isString := v.(string); isString {
// // userid = str
// // }
// // }
// // if v, ok := doc.Metadata["username"]; ok {
// // if str, isString := v.(string); isString {
// // usernmae = str
// // }
// // }
// // }
// // // Create knowledge space.
// // spaceReq := SpaceRequest{
// // Name: userid,
// // VectorType: "KnowledgeGraph",
// // DomainType: "Normal",
// // Desc: usernmae,
// // Owner: userid,
// // }
// // resp, err := ds.client.AddSpace(spaceReq)
// // if err != nil {
// // return fmt.Errorf("add space: %w", err)
// // }
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // body, _ := io.ReadAll(resp.Body)
// // return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
// // }
// // fmt.Println("space ok")
// // spaceId := userid
// // // Index each document.
// // for i, doc := range req.Documents {
// // docName := ""
// // if v, ok := doc.Metadata["doc_name"]; ok {
// // if str, isString := v.(string); isString {
// // docName = str
// // } else {
// // // Generate random docName.
// // var err error
// // docName, err = generateRandomDocName(8)
// // if err != nil {
// // return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// // }
// // }
// // } else {
// // // Generate random docName.
// // var err error
// // docName, err = generateRandomDocName(8)
// // if err != nil {
// // return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// // }
// // }
// // fmt.Println("docName: ", docName)
// // // Add document.
// // var sb strings.Builder
// // for _, p := range doc.Content {
// // if p.IsText() {
// // sb.WriteString(p.Text)
// // }
// // }
// // text := sb.String()
// // fmt.Println("text: ",text)
// // docReq := DocumentRequest{
// // DocName: docName,
// // Source: "api",
// // DocType: "TEXT",
// // Content: text,
// // Labels: "",
// // // Questions: []string{},
// // Metadata: doc.Metadata,
// // }
// // resp, err := ds.client.AddDocument(spaceId, docReq)
// // if err != nil {
// // return fmt.Errorf("add document %d: %w", i+1, err)
// // }
// // body, _ := io.ReadAll(resp.Body)
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
// // }
// // ok, idx, err := ParseJSONResponse(body)
// // if err != nil {
// // return fmt.Errorf("ParseJSONResponse %d: %w", i+1, err)
// // }
// // if !ok{
// // return fmt.Errorf("ParseJSONResponse body %d: %w", i+1, err)
// // }
// // fmt.Println("document ok",string(body),idx)
// // ok ,err =ds.client.SyncDocuments(spaceId,[]string{idx})
// // if err != nil{
// // return err
// // }
// // }
// // return nil
// // }
// // // RetrieverOptions for Knowledge retrieval.
// // type RetrieverOptions struct {
// // Count int `json:"count,omitempty"` // Max documents to retrieve.
// // MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// // }
// // // Assuming ai.Part has a Text() method or Text field to get the string content
// // func partsToString(parts []*ai.Part) string {
// // var texts []string
// // for _, part := range parts {
// // // Adjust this based on the actual ai.Part structure
// // // If ai.Part has a Text() method:
// // texts = append(texts, part.Text)
// // // OR if ai.Part has a Text field:
// // // texts = append(texts, part.Text)
// // }
// // return strings.Join(texts, " ")
// // }
// // // Retrieve implements the Retriever.Retrieve method.
// // func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// // // count := 3
// // // metricTypeStr := "L2"
// // // if req.Options != nil {
// // // ropt, ok := req.Options.(*RetrieverOptions)
// // // if !ok {
// // // return nil, fmt.Errorf("graphrag.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// // // }
// // // if ropt.Count > 0 {
// // // count = ropt.Count
// // // }
// // // if ropt.MetricType != "" {
// // // metricTypeStr = ropt.MetricType
// // // }
// // // }
// // queryContent := partsToString(req.Query.Content)
// // // Format query for retrieval.
// // queryText := fmt.Sprintf("Search for: %s", queryContent)
// // username := "Alice" // Default, override if metadata available.
// // if req.Query.Metadata != nil {
// // if uname, ok := req.Query.Metadata["username"].(string); ok {
// // username = uname
// // }
// // }
// // // Prepare request for chat completions endpoint.
// // url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL)
// // chatReq := struct {
// // ConvUID string `json:"conv_uid"`
// // UserInput string `json:"user_input"`
// // UserName string `json:"user_name"`
// // ChatMode string `json:"chat_mode"`
// // AppCode string `json:"app_code"`
// // Temperature float32 `json:"temperature"`
// // MaxNewTokens int `json:"max_new_tokens"`
// // SelectParam string `json:"select_param"`
// // ModelName string `json:"model_name"`
// // Incremental bool `json:"incremental"`
// // SysCode string `json:"sys_code"`
// // PromptCode string `json:"prompt_code"`
// // ExtInfo map[string]interface{} `json:"ext_info"`
// // }{
// // ConvUID: "",
// // UserInput: queryText,
// // UserName: username,
// // ChatMode: "",
// // AppCode: "",
// // Temperature: 0.5,
// // MaxNewTokens: 4000,
// // SelectParam: "",
// // ModelName: ds.modelName,
// // Incremental: false,
// // SysCode: "",
// // PromptCode: "",
// // ExtInfo: map[string]interface{}{
// // "space_id": ds.spaceID,
// // //"k": count,
// // },
// // }
// // body, err := json.Marshal(chatReq)
// // if err != nil {
// // return nil, fmt.Errorf("marshal chat request: %w", err)
// // }
// // httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
// // if err != nil {
// // return nil, fmt.Errorf("create chat request: %w", err)
// // }
// // httpReq.Header.Set("Accept", "application/json")
// // httpReq.Header.Set("Content-Type", "application/json")
// // client := &http.Client{}
// // resp, err := client.Do(httpReq)
// // if err != nil {
// // return nil, fmt.Errorf("send chat request: %w", err)
// // }
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // body, _ := io.ReadAll(resp.Body)
// // return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
// // }
// // // Parse response
// // var chatResp struct {
// // Success bool `json:"success"`
// // Data struct {
// // Answer []struct {
// // Content string `json:"content"`
// // DocID int `json:"doc_id"`
// // Score float64 `json:"score"`
// // Metadata map[string]interface{} `json:"metadata_map"`
// // } `json:"answer"`
// // } `json:"data"`
// // }
// // if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
// // return nil, fmt.Errorf("decode chat response: %w", err)
// // }
// // var docs []*ai.Document
// // for _, doc := range chatResp.Data.Answer {
// // metadata := doc.Metadata
// // if metadata == nil {
// // metadata = make(map[string]interface{})
// // }
// // // Ensure metadata includes user_id and username.
// // if _, ok := metadata["user_id"]; !ok {
// // metadata["user_id"] = "user123"
// // }
// // if _, ok := metadata["username"]; !ok {
// // metadata["username"] = username
// // }
// // aiDoc := ai.DocumentFromText(doc.Content, metadata)
// // docs = append(docs, aiDoc)
// // }
// // return &ai.RetrieverResponse{
// // Documents: docs,
// // }, nil
// // }
// // ChatRequest 定义请求结构体,匹配单元测试的 curl 请求
// type ChatRequest struct {
// Model string `json:"model"`
// Messages string `json:"messages"`
// Temperature float64 `json:"temperature"`
// TopP float64 `json:"top_p"`
// TopK int `json:"top_k"`
// N int `json:"n"`
// MaxTokens int `json:"max_tokens"`
// Stream bool `json:"stream"`
// RepetitionPenalty float64 `json:"repetition_penalty"`
// FrequencyPenalty float64 `json:"frequency_penalty"`
// PresencePenalty float64 `json:"presence_penalty"`
// ChatMode string `json:"chat_mode"`
// ChatParam string `json:"chat_param"`
// EnableVis bool `json:"enable_vis"`
// }
// // ChatResponse 定义响应结构体,匹配单元测试的 API 响应
// type ChatResponse struct {
// ID string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Model string `json:"model"`
// Choices []struct {
// Index int `json:"index"`
// Message struct {
// Role string `json:"role"`
// Content string `json:"content"`
// ReasoningContent interface{} `json:"reasoning_content"`
// } `json:"message"`
// FinishReason interface{} `json:"finish_reason"`
// } `json:"choices"`
// Usage struct {
// PromptTokens int `json:"prompt_tokens"`
// TotalTokens int `json:"total_tokens"`
// CompletionTokens int `json:"completion_tokens"`
// } `json:"usage"`
// } // }
// // Assuming ai.Part has a Text() method or Text field to get the string content // // Assuming ai.Part has a Text() method or Text field to get the string content
// func partsToString(parts []*ai.Part) string { // func partsToString(parts []*ai.Part) string {
// var texts []string // var texts []string
// for _, part := range parts { // for _, part := range parts {
// // Adjust this based on the actual ai.Part structure // // Adjust this based on the actual ai.Part structure
// // If ai.Part has a Text() method: // // If ai.Part has a Text() method:
// texts = append(texts, part.Text) // texts = append(texts, part.Text)
// // OR if ai.Part has a Text field: // // OR if ai.Part has a Text field:
// // texts = append(texts, part.Text) // // texts = append(texts, part.Text)
// } // }
// return strings.Join(texts, " ") // return strings.Join(texts, " ")
// } // }
// // Retrieve implements the Retriever.Retrieve method. // // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { // func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// // count := 3
// // metricTypeStr := "L2"
// // if req.Options != nil {
// // ropt, ok := req.Options.(*RetrieverOptions)
// // if !ok {
// // return nil, fmt.Errorf("graphrag.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// // }
// // if ropt.Count > 0 {
// // count = ropt.Count
// // }
// // if ropt.MetricType != "" {
// // metricTypeStr = ropt.MetricType
// // }
// // }
// queryContent := partsToString(req.Query.Content)
// // Format query for retrieval. // // Format query for retrieval.
// queryContent := partsToString(req.Query.Content)
// queryText := fmt.Sprintf("Search for: %s", queryContent) // queryText := fmt.Sprintf("Search for: %s", queryContent)
// username := "Alice" // Default, override if metadata available. // if req.Query.Metadata == nil {
// if req.Query.Metadata != nil { // // If ok, we don't use the User struct since the requirement is to error on non-nil
// if uname, ok := req.Query.Metadata["username"].(string); ok { // return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// username = uname // }
// }
// for k, v := range req.Query.Metadata {
// fmt.Println("k", k, "v", v)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// } // }
// // Prepare request for chat completions endpoint. // // Prepare request for chat completions endpoint.
// url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL) // url := fmt.Sprintf("%s/api/v2/chat/completions", ds.client.BaseURL)
// chatReq := struct { // chatReq := ChatRequest{
// ConvUID string `json:"conv_uid"` // Model: ds.modelName,
// UserInput string `json:"user_input"` // Messages: queryText,
// UserName string `json:"user_name"` // Temperature: 0.7,
// ChatMode string `json:"chat_mode"` // TopP: 1,
// AppCode string `json:"app_code"` // TopK: -1,
// Temperature float32 `json:"temperature"` // N: 1,
// MaxNewTokens int `json:"max_new_tokens"` // MaxTokens: 0,
// SelectParam string `json:"select_param"` // Stream: false,
// ModelName string `json:"model_name"` // RepetitionPenalty: 1,
// Incremental bool `json:"incremental"` // FrequencyPenalty: 0,
// SysCode string `json:"sys_code"` // PresencePenalty: 0,
// PromptCode string `json:"prompt_code"` // ChatMode: "chat_knowledge",
// ExtInfo map[string]interface{} `json:"ext_info"` // ChatParam: userId,
// }{ // EnableVis: true,
// ConvUID: "",
// UserInput: queryText,
// UserName: username,
// ChatMode: "",
// AppCode: "",
// Temperature: 0.5,
// MaxNewTokens: 4000,
// SelectParam: "",
// ModelName: ds.modelName,
// Incremental: false,
// SysCode: "",
// PromptCode: "",
// ExtInfo: map[string]interface{}{
// "space_id": ds.spaceID,
// //"k": count,
// },
// } // }
// body, err := json.Marshal(chatReq) // body, err := json.Marshal(chatReq)
...@@ -713,35 +1788,21 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -713,35 +1788,21 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// } // }
// // Parse response // // Parse response
// var chatResp struct { // var chatResp ChatResponse
// Success bool `json:"success"`
// Data struct {
// Answer []struct {
// Content string `json:"content"`
// DocID int `json:"doc_id"`
// Score float64 `json:"score"`
// Metadata map[string]interface{} `json:"metadata_map"`
// } `json:"answer"`
// } `json:"data"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { // if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
// return nil, fmt.Errorf("decode chat response: %w", err) // body, _ := io.ReadAll(resp.Body)
// return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(body))
// } // }
// // Convert response to ai.Document
// var docs []*ai.Document // var docs []*ai.Document
// for _, doc := range chatResp.Data.Answer { // if len(chatResp.Choices) > 0 {
// metadata := doc.Metadata // content := chatResp.Choices[0].Message.Content
// if metadata == nil { // metadata := map[string]interface{}{
// metadata = make(map[string]interface{}) // util.UserIdKey: userId,
// } // util.UserNameKey: userName,
// // Ensure metadata includes user_id and username.
// if _, ok := metadata["user_id"]; !ok {
// metadata["user_id"] = "user123"
// }
// if _, ok := metadata["username"]; !ok {
// metadata["username"] = username
// } // }
// aiDoc := ai.DocumentFromText(doc.Content, metadata) // aiDoc := ai.DocumentFromText(content, metadata)
// docs = append(docs, aiDoc) // docs = append(docs, aiDoc)
// } // }
...@@ -749,149 +1810,3 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -749,149 +1810,3 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// Documents: docs, // Documents: docs,
// }, nil // }, nil
// } // }
// ChatRequest 定义请求结构体,匹配单元测试的 curl 请求
type ChatRequest struct {
Model string `json:"model"`
Messages string `json:"messages"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
N int `json:"n"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
RepetitionPenalty float64 `json:"repetition_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
ChatMode string `json:"chat_mode"`
ChatParam string `json:"chat_param"`
EnableVis bool `json:"enable_vis"`
}
// ChatResponse 定义响应结构体,匹配单元测试的 API 响应
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent interface{} `json:"reasoning_content"`
} `json:"message"`
FinishReason interface{} `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
// Assuming ai.Part has a Text() method or Text field to get the string content
func partsToString(parts []*ai.Part) string {
var texts []string
for _, part := range parts {
// Adjust this based on the actual ai.Part structure
// If ai.Part has a Text() method:
texts = append(texts, part.Text)
// OR if ai.Part has a Text field:
// texts = append(texts, part.Text)
}
return strings.Join(texts, " ")
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Format query for retrieval.
queryContent := partsToString(req.Query.Content)
queryText := fmt.Sprintf("Search for: %s", queryContent)
if req.Query.Metadata == nil {
// If ok, we don't use the User struct since the requirement is to error on non-nil
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
}
for k, v := range req.Query.Metadata {
fmt.Println("k", k, "v", v)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
// Prepare request for chat completions endpoint.
url := fmt.Sprintf("%s/api/v2/chat/completions", ds.client.BaseURL)
chatReq := ChatRequest{
Model: ds.modelName,
Messages: queryText,
Temperature: 0.7,
TopP: 1,
TopK: -1,
N: 1,
MaxTokens: 0,
Stream: false,
RepetitionPenalty: 1,
FrequencyPenalty: 0,
PresencePenalty: 0,
ChatMode: "chat_knowledge",
ChatParam: userId,
EnableVis: true,
}
body, err := json.Marshal(chatReq)
if err != nil {
return nil, fmt.Errorf("marshal chat request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("create chat request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send chat request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var chatResp ChatResponse
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(body))
}
// Convert response to ai.Document
var docs []*ai.Document
if len(chatResp.Choices) > 0 {
content := chatResp.Choices[0].Message.Content
metadata := map[string]interface{}{
util.UserIdKey: userId,
util.UserNameKey: userName,
}
aiDoc := ai.DocumentFromText(content, metadata)
docs = append(docs, aiDoc)
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
...@@ -30,6 +30,7 @@ import ( ...@@ -30,6 +30,7 @@ import (
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/milvus-io/milvus-sdk-go/v2/client" "github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity" "github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/rs/zerolog/log"
"github.com/wade-liwei/agentchat/util" "github.com/wade-liwei/agentchat/util"
) )
...@@ -38,235 +39,818 @@ const provider = "milvus" ...@@ -38,235 +39,818 @@ const provider = "milvus"
// Field names for Milvus schema. // Field names for Milvus schema.
const ( const (
idField = "id" idField = "id"
vectorField = "vector" vectorField = "vector"
textField = "text" textField = "text"
metadataField = "metadata" metadataField = "metadata"
) )
// Milvus holds configuration for the plugin. // Milvus holds configuration for the plugin.
type Milvus struct { type Milvus struct {
// Milvus server address (host:port, e.g., "localhost:19530"). // Milvus server address (host:port, e.g., "localhost:19530").
// Defaults to MILVUS_ADDRESS environment variable. // Defaults to MILVUS_ADDRESS environment variable.
Addr string Addr string
// Username for authentication. // Username for authentication.
// Defaults to MILVUS_USERNAME. // Defaults to MILVUS_USERNAME.
Username string Username string
// Password for authentication. // Password for authentication.
// Defaults to MILVUS_PASSWORD. // Defaults to MILVUS_PASSWORD.
Password string Password string
// Token for authentication (alternative to username/password). // Token for authentication (alternative to username/password).
// Defaults to MILVUS_TOKEN. // Defaults to MILVUS_TOKEN.
Token string Token string
client client.Client // Milvus client. client client.Client // Milvus client.
mu sync.Mutex // Mutex to control access. mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized. initted bool // Whether the plugin has been initialized.
} }
// Name returns the plugin name. // Name returns the plugin name.
func (m *Milvus) Name() string { func (m *Milvus) Name() string {
return provider return provider
} }
// Init initializes the Milvus plugin. // Init initializes the Milvus plugin.
func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) { func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if m == nil { log.Info().Str("method", "Milvus.Init").Msg("Initializing Milvus plugin")
m = &Milvus{} if m == nil {
} m = &Milvus{}
}
m.mu.Lock()
defer m.mu.Unlock() m.mu.Lock()
defer func() { defer m.mu.Unlock()
if err != nil { defer func() {
err = fmt.Errorf("milvus.Init: %w", err) if err != nil {
} log.Error().Err(err).Str("method", "Milvus.Init").Msg("Initialization failed")
}() err = fmt.Errorf("milvus.Init: %w", err)
} else {
if m.initted { log.Info().Str("method", "Milvus.Init").Msg("Initialization successful")
return errors.New("plugin already initialized") }
} }()
// Load configuration. if m.initted {
addr := m.Addr return errors.New("plugin already initialized")
if addr == "" { }
addr = os.Getenv("MILVUS_ADDRESS")
} // Load configuration.
if addr == "" { addr := m.Addr
return errors.New("milvus address required") if addr == "" {
} addr = os.Getenv("MILVUS_ADDRESS")
}
username := m.Username if addr == "" {
if username == "" { return errors.New("milvus address required")
username = os.Getenv("MILVUS_USERNAME") }
}
password := m.Password username := m.Username
if password == "" { if username == "" {
password = os.Getenv("MILVUS_PASSWORD") username = os.Getenv("MILVUS_USERNAME")
} }
token := m.Token password := m.Password
if token == "" { if password == "" {
token = os.Getenv("MILVUS_TOKEN") password = os.Getenv("MILVUS_PASSWORD")
} }
token := m.Token
// Initialize Milvus client. if token == "" {
config := client.Config{ token = os.Getenv("MILVUS_TOKEN")
Address: addr, }
Username: username,
Password: password, // Initialize Milvus client.
APIKey: token, config := client.Config{
} Address: addr,
client, err := client.NewClient(ctx, config) Username: username,
if err != nil { Password: password,
return fmt.Errorf("failed to initialize Milvus client: %v", err) APIKey: token,
} }
client, err := client.NewClient(ctx, config)
m.client = client if err != nil {
m.initted = true return fmt.Errorf("failed to initialize Milvus client: %v", err)
return nil }
m.client = client
m.initted = true
return nil
} }
// CollectionConfig holds configuration for an indexer/retriever pair. // CollectionConfig holds configuration for an indexer/retriever pair.
type CollectionConfig struct { type CollectionConfig struct {
// Milvus collection name. Must not be empty. // Milvus collection name. Must not be empty.
Collection string Collection string
// Embedding vector dimension (e.g., 1536 for text-embedding-ada-002). // Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
Dimension int Dimension int
// Embedder for generating vectors. // Embedder for generating vectors.
Embedder ai.Embedder Embedder ai.Embedder
// Embedder options. // Embedder options.
EmbedderOptions any EmbedderOptions any
} }
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection. // DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) { func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
if cfg.Embedder == nil { log.Info().
return nil, nil, errors.New("milvus: Embedder required") Str("method", "DefineIndexerAndRetriever").
} Str("collection", cfg.Collection).
if cfg.Collection == "" { Int("dimension", cfg.Dimension).
return nil, nil, errors.New("milvus: collection name required") Msg("Defining indexer and retriever")
}
if cfg.Dimension <= 0 { if cfg.Embedder == nil {
return nil, nil, errors.New("milvus: dimension must be positive") log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Embedder required")
} return nil, nil, errors.New("milvus: Embedder required")
}
m := genkit.LookupPlugin(g, provider) if cfg.Collection == "" {
if m == nil { log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Collection name required")
return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?") return nil, nil, errors.New("milvus: collection name required")
} }
milvus := m.(*Milvus) if cfg.Dimension <= 0 {
log.Error().Str("method", "DefineIndexerAndRetriever").Int("dimension", cfg.Dimension).Msg("Dimension must be positive")
ds, err := milvus.newDocStore(ctx, &cfg) return nil, nil, errors.New("milvus: dimension must be positive")
if err != nil { }
return nil, nil, err
} m := genkit.LookupPlugin(g, provider)
if m == nil {
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index) log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Milvus plugin not found")
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve) return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
return indexer, retriever, nil }
milvus := m.(*Milvus)
ds, err := milvus.newDocStore(ctx, &cfg)
if err != nil {
log.Error().Err(err).Str("method", "DefineIndexerAndRetriever").Str("collection", cfg.Collection).Msg("Failed to create doc store")
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
log.Info().Str("method", "DefineIndexerAndRetriever").Str("collection", cfg.Collection).Msg("Indexer and retriever defined successfully")
return indexer, retriever, nil
} }
// docStore defines an Indexer and a Retriever. // docStore defines an Indexer and a Retriever.
type docStore struct { type docStore struct {
client client.Client client client.Client
collection string collection string
dimension int dimension int
embedder ai.Embedder embedder ai.Embedder
embedderOptions map[string]interface{} embedderOptions map[string]interface{}
} }
// // newDocStore creates a docStore. // newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) { func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil { log.Info().
// return nil, errors.New("milvus.Init not called") Str("method", "Milvus.newDocStore").
// } Str("collection", cfg.Collection).
Int("dimension", cfg.Dimension).
Msg("Creating new doc store")
if m.client == nil {
log.Error().Str("method", "Milvus.newDocStore").Msg("Milvus client not initialized")
return nil, errors.New("milvus.Init not called")
}
// Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to check collection")
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
}
if !exists {
// Define schema with textField as primary key, plus user_id and username fields.
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true,
TypeParams: map[string]string{
"max_length": "65535",
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
{
Name: "user_id",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128",
},
},
{
Name: "username",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128",
},
},
},
}
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to create collection")
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
// Create HNSW index for vectorField.
index, err := entity.NewIndexHNSW(
entity.L2,
8, // M
96, // efConstruction
)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to create HNSW index")
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
log.Error().Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msgf("Failed to create index: %s",err.Error())
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to load collection")
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
// Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
log.Error().
Str("method", "Milvus.newDocStore").
Str("type", fmt.Sprintf("%T", cfg.EmbedderOptions)).
Msg("EmbedderOptions must be a map[string]interface{}")
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
} else {
embedderOptions = make(map[string]interface{})
}
log.Info().Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Doc store created successfully")
return &docStore{
client: m.client,
collection: cfg.Collection,
dimension: cfg.Dimension,
embedder: cfg.Embedder,
embedderOptions: embedderOptions,
}, nil
}
// // Check/create collection. // Indexer returns the indexer for a collection.
// exists, err := m.client.HasCollection(ctx, cfg.Collection) func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
// if err != nil { log.Info().Str("method", "Indexer").Str("collection", collection).Msg("Looking up indexer")
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err) indexer := genkit.LookupIndexer(g, provider, collection)
if indexer == nil {
log.Warn().Str("method", "Indexer").Str("collection", collection).Msg("Indexer not found")
}
return indexer
}
// Retriever returns the retriever for a collection.
func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
log.Info().Str("method", "Retriever").Str("collection", collection).Msg("Looking up retriever")
retriever := genkit.LookupRetriever(g, provider, collection)
if retriever == nil {
log.Warn().Str("method", "Retriever").Str("collection", collection).Msg("Retriever not found")
}
return retriever
}
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
log.Info().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("documents", len(req.Documents)).
Msg("Starting index operation")
if len(req.Documents) == 0 {
log.Debug().Str("method", "docStore.Index").Str("collection", ds.collection).Msg("No documents to index")
return nil
}
// Embed documents.
ereq := &ai.EmbedRequest{
Input: req.Documents,
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Index").Str("collection", ds.collection).Msg("Embedding failed")
return fmt.Errorf("milvus index embedding failed: %w", err)
}
// Validate embedding count matches document count.
if len(eres.Embeddings) != len(req.Documents) {
log.Error().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("embeddings", len(eres.Embeddings)).
Int("documents", len(req.Documents)).
Msg("Mismatch in embedding and document count")
return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
}
// Prepare row-based data.
var rows []interface{}
for i, emb := range eres.Embeddings {
doc := req.Documents[i]
if doc.Metadata == nil {
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Document metadata is nil")
return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", doc.Metadata)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := doc.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Missing username in metadata")
return fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := doc.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Missing user_id in metadata")
return fmt.Errorf("req.Query.Metadata must provide user_id key")
}
var sb strings.Builder
for _, p := range doc.Content {
if p.IsText() {
sb.WriteString(p.Text)
}
}
text := sb.String()
metadata := doc.Metadata
if metadata == nil {
metadata = make(map[string]interface{})
}
// Create row with explicit metadata field.
row := make(map[string]interface{})
row["vector"] = emb.Embedding
row["text"] = text
row["user_id"] = userId
row["username"] = userName
row["metadata"] = metadata
rows = append(rows, row)
log.Debug().
Str("method", "docStore.Index").
Int("index", i).
Str("collection", ds.collection).
Int("vector_length", len(emb.Embedding)).
Str("text", text).
Str("user_id", userId).
Str("username", userName).
Interface("metadata", metadata).
Msg("Prepared row for insertion")
}
log.Info().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("rows", len(rows)).
Msg("Inserting rows into Milvus")
// Insert rows into Milvus.
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Index").Str("collection", ds.collection).Msg("Failed to insert rows")
return fmt.Errorf("milvus insert rows failed: %w", err)
}
log.Info().Str("method", "docStore.Index").Str("collection", ds.collection).Int("rows", len(rows)).Msg("Index operation completed successfully")
return nil
}
// RetrieverOptions for Milvus retrieval.
type RetrieverOptions struct {
Count int `json:"count,omitempty"` // Max documents to retrieve.
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
log.Info().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Msg("Starting retrieve operation")
if req.Query.Metadata == nil {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Query metadata is nil")
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Missing username in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Missing user_id in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
log.Error().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("options_type", fmt.Sprintf("%T", req.Options)).
Msg("Invalid options type")
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
switch metricTypeStr {
case "L2":
metricType = entity.L2
case "IP":
metricType = entity.IP
default:
log.Error().Str("method", "docStore.Retrieve").Str("metric_type", metricTypeStr).Msg("Unsupported metric type")
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Embedding failed")
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("No embeddings generated")
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Failed to create HNSW search parameters")
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Define filter expression for user_id
expr := fmt.Sprintf("user_id == %q", userId)
log.Debug().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("user_id", userId).
Str("metric_type", metricTypeStr).
Int("count", count).
Msg("Performing vector search")
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{},
expr,
[]string{textField, metadataField},
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Search failed")
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists.
if textCol == nil {
log.Error().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("field", textField).
Msg("Text column not found in search results")
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// Iterate over rows.
for i := 0; i < result.ResultCount; i++ {
// Get text value.
text, err := textCol.GetAsString(i)
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to parse text")
continue
}
// Get metadata value (optional).
var metadata map[string]interface{}
if metaCol != nil {
metaStr, err := metaCol.GetAsString(i)
if err == nil && metaStr != "" {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to parse metadata")
continue
}
} else if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to get metadata string")
}
}
// Ensure metadata includes user_id and username from query.
if metadata == nil {
metadata = make(map[string]interface{})
}
metadata[util.UserIdKey] = userId
metadata[util.UserNameKey] = userName
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
log.Debug().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Str("text", text).
Interface("metadata", metadata).
Msg("Processed search result")
}
}
log.Info().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("documents", len(docs)).
Msg("Retrieve operation completed successfully")
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// // Package milvus provides a Genkit plugin for Milvus vector database using milvus-sdk-go.
// package milvus
// import (
// "context"
// "encoding/json"
// "errors"
// "fmt"
// "os"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/milvus-io/milvus-sdk-go/v2/client"
// "github.com/milvus-io/milvus-sdk-go/v2/entity"
// "github.com/wade-liwei/agentchat/util"
// )
// // The provider used in the registry.
// const provider = "milvus"
// // Field names for Milvus schema.
// const (
// idField = "id"
// vectorField = "vector"
// textField = "text"
// metadataField = "metadata"
// )
// // Milvus holds configuration for the plugin.
// type Milvus struct {
// // Milvus server address (host:port, e.g., "localhost:19530").
// // Defaults to MILVUS_ADDRESS environment variable.
// Addr string
// // Username for authentication.
// // Defaults to MILVUS_USERNAME.
// Username string
// // Password for authentication.
// // Defaults to MILVUS_PASSWORD.
// Password string
// // Token for authentication (alternative to username/password).
// // Defaults to MILVUS_TOKEN.
// Token string
// client client.Client // Milvus client.
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the plugin name.
// func (m *Milvus) Name() string {
// return provider
// }
// // Init initializes the Milvus plugin.
// func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
// if m == nil {
// m = &Milvus{}
// } // }
// if !exists {
// // Define schema.
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: idField,
// DataType: entity.FieldTypeInt64,
// PrimaryKey: true,
// AutoID: true,
// },
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "65535",
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) // m.mu.Lock()
// defer m.mu.Unlock()
// defer func() {
// if err != nil { // if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err) // err = fmt.Errorf("milvus.Init: %w", err)
// } // }
// }()
// // Create HNSW index. // if m.initted {
// index, err := entity.NewIndexHNSW( // return errors.New("plugin already initialized")
// entity.L2, // }
// 8, // M
// 96, // efConstruction
// )
// if err != nil {
// return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
// }
// err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false) // // Load configuration.
// if err != nil { // addr := m.Addr
// return nil, fmt.Errorf("failed to create index: %v", err) // if addr == "" {
// } // addr = os.Getenv("MILVUS_ADDRESS")
// }
// if addr == "" {
// return errors.New("milvus address required")
// } // }
// // Load collection. // username := m.Username
// err = m.client.LoadCollection(ctx, cfg.Collection, false) // if username == "" {
// username = os.Getenv("MILVUS_USERNAME")
// }
// password := m.Password
// if password == "" {
// password = os.Getenv("MILVUS_PASSWORD")
// }
// token := m.Token
// if token == "" {
// token = os.Getenv("MILVUS_TOKEN")
// }
// // Initialize Milvus client.
// config := client.Config{
// Address: addr,
// Username: username,
// Password: password,
// APIKey: token,
// }
// client, err := client.NewClient(ctx, config)
// if err != nil { // if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err) // return fmt.Errorf("failed to initialize Milvus client: %v", err)
// } // }
// // Convert EmbedderOptions to map[string]interface{}. // m.client = client
// var embedderOptions map[string]interface{} // m.initted = true
// if cfg.EmbedderOptions != nil { // return nil
// opts, ok := cfg.EmbedderOptions.(map[string]interface{}) // }
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions) // // CollectionConfig holds configuration for an indexer/retriever pair.
// } // type CollectionConfig struct {
// embedderOptions = opts // // Milvus collection name. Must not be empty.
// } else { // Collection string
// embedderOptions = make(map[string]interface{}) // // Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
// Dimension int
// // Embedder for generating vectors.
// Embedder ai.Embedder
// // Embedder options.
// EmbedderOptions any
// }
// // DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
// func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
// if cfg.Embedder == nil {
// return nil, nil, errors.New("milvus: Embedder required")
// }
// if cfg.Collection == "" {
// return nil, nil, errors.New("milvus: collection name required")
// }
// if cfg.Dimension <= 0 {
// return nil, nil, errors.New("milvus: dimension must be positive")
// } // }
// return &docStore{ // m := genkit.LookupPlugin(g, provider)
// client: m.client, // if m == nil {
// collection: cfg.Collection, // return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
// dimension: cfg.Dimension, // }
// embedder: cfg.Embedder, // milvus := m.(*Milvus)
// embedderOptions: embedderOptions,
// }, nil // ds, err := milvus.newDocStore(ctx, &cfg)
// if err != nil {
// return nil, nil, err
// }
// indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
// retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
// return indexer, retriever, nil
// } // }
// newDocStore creates a docStore. // // docStore defines an Indexer and a Retriever.
// type docStore struct {
// client client.Client
// collection string
// dimension int
// embedder ai.Embedder
// embedderOptions map[string]interface{}
// }
// // newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) { // func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil { // if m.client == nil {
// return nil, errors.New("milvus.Init not called") // return nil, errors.New("milvus.Init not called")
...@@ -278,16 +862,10 @@ type docStore struct { ...@@ -278,16 +862,10 @@ type docStore struct {
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err) // return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
// } // }
// if !exists { // if !exists {
// // Define schema with textField as primary key for unique constraint. // // Define schema with textField as primary key, plus user_id and username fields.
// schema := &entity.Schema{ // schema := &entity.Schema{
// CollectionName: cfg.Collection, // CollectionName: cfg.Collection,
// Fields: []*entity.Field{ // Fields: []*entity.Field{
// // {
// // Name: idField, // Optional non-primary ID field
// // DataType: entity.FieldTypeInt64,
// // //AutoID: true,
// // // No PrimaryKey or AutoID, as textField is the primary key
// // },
// { // {
// Name: vectorField, // Name: vectorField,
// DataType: entity.FieldTypeFloatVector, // DataType: entity.FieldTypeFloatVector,
...@@ -300,50 +878,36 @@ type docStore struct { ...@@ -300,50 +878,36 @@ type docStore struct {
// DataType: entity.FieldTypeVarChar, // DataType: entity.FieldTypeVarChar,
// PrimaryKey: true, // Enforce unique constraint on text field // PrimaryKey: true, // Enforce unique constraint on text field
// TypeParams: map[string]string{ // TypeParams: map[string]string{
// "max_length": "65535", // Maximum length for VARCHAR, adjust if needed // "max_length": "65535", // Maximum length for VARCHAR
// }, // },
// }, // },
// { // {
// Name: metadataField, // Name: metadataField,
// DataType: entity.FieldTypeJSON, // DataType: entity.FieldTypeJSON,
// }, // },
// {
// Name: "user_id",
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "128", // Reasonable length for user_id
// },
// },
// {
// Name: "username",
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "128", // Reasonable length for username
// },
// },
// }, // },
// } // }
// // Alternative: Remove idField if not needed
// /*
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// PrimaryKey: true, // Enforce unique constraint on text field
// TypeParams: map[string]string{
// "max_length": "65535",
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// */
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) // err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
// if err != nil { // if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err) // return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
// } // }
// // Create HNSW index. // // Create HNSW index for vectorField.
// index, err := entity.NewIndexHNSW( // index, err := entity.NewIndexHNSW(
// entity.L2, // entity.L2,
// 8, // M // 8, // M
...@@ -386,220 +950,118 @@ type docStore struct { ...@@ -386,220 +950,118 @@ type docStore struct {
// }, nil // }, nil
// } // }
// newDocStore creates a docStore. // // Indexer returns the indexer for a collection.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) { // func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
if m.client == nil { // return genkit.LookupIndexer(g, provider, collection)
return nil, errors.New("milvus.Init not called") // }
}
// Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil {
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
}
if !exists {
// Define schema with textField as primary key, plus user_id and username fields.
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
TypeParams: map[string]string{
"max_length": "65535", // Maximum length for VARCHAR
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
{
Name: "user_id",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128", // Reasonable length for user_id
},
},
{
Name: "username",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128", // Reasonable length for username
},
},
},
}
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
// Create HNSW index for vectorField.
index, err := entity.NewIndexHNSW(
entity.L2,
8, // M
96, // efConstruction
)
if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
// Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
} else {
embedderOptions = make(map[string]interface{})
}
return &docStore{
client: m.client,
collection: cfg.Collection,
dimension: cfg.Dimension,
embedder: cfg.Embedder,
embedderOptions: embedderOptions,
}, nil
}
// Indexer returns the indexer for a collection. // // Retriever returns the retriever for a collection.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer { // func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupIndexer(g, provider, collection) // return genkit.LookupRetriever(g, provider, collection)
} // }
// Retriever returns the retriever for a collection. // /*
func Retriever(g *genkit.Genkit, collection string) ai.Retriever { // 更新 删除 很少用到;
return genkit.LookupRetriever(g, provider, collection) // */
}
/* // // Index implements the Indexer.Index method.
更新 删除 很少用到; // func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
*/ // if len(req.Documents) == 0 {
// return nil
// }
// Index implements the Indexer.Index method. // // Embed documents.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { // ereq := &ai.EmbedRequest{
if len(req.Documents) == 0 { // Input: req.Documents,
return nil // Options: ds.embedderOptions,
} // }
// eres, err := ds.embedder.Embed(ctx, ereq)
// Embed documents. // if err != nil {
ereq := &ai.EmbedRequest{ // return fmt.Errorf("milvus index embedding failed: %w", err)
Input: req.Documents, // }
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("milvus index embedding failed: %w", err)
}
// Validate embedding count matches document count.
if len(eres.Embeddings) != len(req.Documents) {
return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
}
// Prepare row-based data.
var rows []interface{}
for i, emb := range eres.Embeddings {
doc := req.Documents[i]
if doc.Metadata == nil {
// If ok, we don't use the User struct since the requirement is to error on non-nil
return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := doc.Metadata[util.UserNameKey].(string)
if !ok {
return fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := doc.Metadata[util.UserIdKey].(string)
if !ok {
return fmt.Errorf("req.Query.Metadata must provide user_id key")
}
var sb strings.Builder
for _, p := range doc.Content {
if p.IsText() {
sb.WriteString(p.Text)
}
}
text := sb.String()
metadata := doc.Metadata
if metadata == nil {
metadata = make(map[string]interface{})
}
// Create row with explicit metadata field.
row := make(map[string]interface{})
row["vector"] = emb.Embedding // []float32
row["text"] = text
row["user_id"] = userId
row["username"] = userName
row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
rows = append(rows, row)
// Debug: Log row contents.
fmt.Printf("Row %d: vector_len=%d, text=%q,userId=%s,username=%s,metadata=%v\n", i, len(emb.Embedding), text, userId, userName, metadata)
}
// Debug: Log total rows.
fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// Insert rows into Milvus.
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil {
return fmt.Errorf("milvus insert rows failed: %w", err)
}
return nil
}
// RetrieverOptions for Milvus retrieval. // // Validate embedding count matches document count.
type RetrieverOptions struct { // if len(eres.Embeddings) != len(req.Documents) {
Count int `json:"count,omitempty"` // Max documents to retrieve. // return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP"). // }
}
// // Prepare row-based data.
// var rows []interface{}
// for i, emb := range eres.Embeddings {
// doc := req.Documents[i]
// if doc.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil
// return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := doc.Metadata[util.UserNameKey].(string)
// if !ok {
// return fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := doc.Metadata[util.UserIdKey].(string)
// if !ok {
// return fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// metadata := doc.Metadata
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// // Create row with explicit metadata field.
// row := make(map[string]interface{})
// row["vector"] = emb.Embedding // []float32
// row["text"] = text
// row["user_id"] = userId
// row["username"] = userName
// row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
// rows = append(rows, row)
// // Debug: Log row contents.
// fmt.Printf("Row %d: vector_len=%d, text=%q,userId=%s,username=%s,metadata=%v\n", i, len(emb.Embedding), text, userId, userName, metadata)
// }
// // Debug: Log total rows.
// fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// // Insert rows into Milvus.
// _, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
// if err != nil {
// return fmt.Errorf("milvus insert rows failed: %w", err)
// }
// return nil
// }
// // RetrieverOptions for Milvus retrieval.
// type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve.
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// }
// // Retrieve implements the Retriever.Retrieve method. // // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { // func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// if req.Query.Metadata == nil { // if req.Query.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil // return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
// return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// } // }
// // Extract username and user_id from req.Query.Metadata // // Extract username and user_id from req.Query.Metadata
// userName, ok := req.Query.Metadata[util.UserNameKey].(string) // userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// if !ok { // if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key") // return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// } // }
// userId, ok := req.Query.Metadata[util.UserIdKey].(string) // userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// if !ok { // if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key") // return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// } // }
// count := 3 // Default. // count := 3 // Default.
...@@ -648,12 +1110,15 @@ type RetrieverOptions struct { ...@@ -648,12 +1110,15 @@ type RetrieverOptions struct {
// return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err) // return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// } // }
// // Define filter expression for user_id
// expr := fmt.Sprintf("user_id == %q", userId)
// // Perform vector search to get IDs, text, and metadata. // // Perform vector search to get IDs, text, and metadata.
// results, err := ds.client.Search( // results, err := ds.client.Search(
// ctx, // ctx,
// ds.collection, // ds.collection,
// []string{}, // partitions // []string{}, // partitions
// "", // expr (TODO: add metadata filter if needed) // expr, // Filter by user_id
// []string{textField, metadataField}, // Output fields: text and metadata // []string{textField, metadataField}, // Output fields: text and metadata
// []entity.Vector{queryVector}, // []entity.Vector{queryVector},
// vectorField, // vectorField,
...@@ -707,8 +1172,12 @@ type RetrieverOptions struct { ...@@ -707,8 +1172,12 @@ type RetrieverOptions struct {
// } // }
// } // }
// // Print text and metadata in a format similar to insertion debug log. // // Ensure metadata includes user_id and username from query
// // fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata) // if metadata == nil {
// metadata = make(map[string]interface{})
// }
// metadata[util.UserIdKey] = userId
// metadata[util.UserNameKey] = userName
// // Create document. // // Create document.
// doc := ai.DocumentFromText(text, metadata) // doc := ai.DocumentFromText(text, metadata)
...@@ -720,145 +1189,3 @@ type RetrieverOptions struct { ...@@ -720,145 +1189,3 @@ type RetrieverOptions struct {
// Documents: docs, // Documents: docs,
// }, nil // }, nil
// } // }
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
if req.Query.Metadata == nil {
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
switch metricTypeStr {
case "L2":
metricType = entity.L2
case "IP":
metricType = entity.IP
default:
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Define filter expression for user_id
expr := fmt.Sprintf("user_id == %q", userId)
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
expr, // Filter by user_id
[]string{textField, metadataField}, // Output fields: text and metadata
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists.
if textCol == nil {
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// Iterate over rows (assuming columns have same length).
for i := 0; i < result.ResultCount; i++ {
// Get text value.
text, err := textCol.GetAsString(i)
if err != nil {
fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
continue
}
// Get metadata value (optional, as metadata column may be missing).
var metadata map[string]interface{}
if metaCol != nil {
metaStr, err := metaCol.GetAsString(i)
if err == nil && metaStr != "" {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
continue
}
} else if err != nil {
fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
}
}
// Ensure metadata includes user_id and username from query
if metadata == nil {
metadata = make(map[string]interface{})
}
metadata[util.UserIdKey] = userId
metadata[util.UserNameKey] = userName
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
}
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment