Commit 8604b187 authored by Wade's avatar Wade

plugins use new log lib

parent 16d5d242
{"level":"info","pid":11627,"time":1749036605,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":11627,"time":1749036651,"caller":"/Users/wade/project/wuban/agentchat/main.go:229","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":11627,"time":1749036651,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"qaAsJson--------{\"ID\":16,\"CreatedAt\":\"2025-06-04T09:50:02.837091Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":null,\"Summary\":null,\"To\":\"Bob\",\"ToID\":\"user4567\"}"}
{"level":"info","pid":11627,"time":1749036656,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":11627,"time":1749036665,"caller":"/Users/wade/project/wuban/agentchat/main.go:294","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":12824,"time":1749036924,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":12824,"time":1749036930,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":12824,"time":1749036931,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":27,\"CreatedAt\":\"2025-06-04T11:30:53.508295Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"Well now, if Paris is the heart of France, and the US loves its coffee, then you're probably wondering about the UK. The capital of the UK is London, a truly grand city!\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":12824,"time":1749036933,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":12824,"time":1749036937,"caller":"/Users/wade/project/wuban/agentchat/main.go:320","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":18861,"time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/log.go:69","message":"This message appears when log level set to Debug or Info"}
{"level":"info","pid":18861,"method":"DeepSeek.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/deepseek/deepseek.go:91","message":"Initializing DeepSeek plugin"}
{"level":"info","pid":18861,"method":"DeepSeek.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/deepseek/deepseek.go:104","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"Milvus.Init","time":1749038762,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:75","message":"Initializing Milvus plugin"}
{"level":"info","pid":18861,"method":"Milvus.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:87","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"GraphKnowledge.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:299","message":"Initializing GraphKnowledge plugin"}
{"level":"info","pid":18861,"method":"NewClient","ip":"54.92.111.204","port":5670,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:93","message":"Creating new GraphRAG client"}
{"level":"info","pid":18861,"method":"GraphKnowledge.Init","time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:311","message":"Initialization successful"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","collection":"chatRag1","dimension":768,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:152","message":"Defining indexer and retriever"}
{"level":"info","pid":18861,"method":"Milvus.newDocStore","collection":"chatRag1","dimension":768,"time":1749038764,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:201","message":"Creating new doc store"}
{"level":"info","pid":18861,"method":"Milvus.newDocStore","collection":"chatRag1","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:302","message":"Doc store created successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","collection":"chatRag1","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:182","message":"Indexer and retriever defined successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:357","message":"Defining indexer and retriever"}
{"level":"info","pid":18861,"method":"GraphKnowledge.newDocStore","space_id":"","model_name":"Qwen/Qwen2.5-Coder-32B-Instruct","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:393","message":"Creating new doc store"}
{"level":"info","pid":18861,"method":"GraphKnowledge.newDocStore","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:399","message":"Doc store created successfully"}
{"level":"info","pid":18861,"method":"DefineIndexerAndRetriever","time":1749038765,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:376","message":"Indexer and retriever defined successfully"}
{"level":"info","pid":18861,"time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":18861,"time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":28,\"CreatedAt\":\"2025-06-04T11:35:33.142254Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"I'm sorry, but the provided context doesn't contain information about the capital of the UK.\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","time":1749038774,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:450","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","documents":2,"time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:640","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","time":1749038778,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:755","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","documents":1,"time":1749038786,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:892","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038786,"caller":"/Users/wade/project/wuban/agentchat/main.go:320","message":"promptInput.Graph : 知识库中提供的内容不足以回答此问题\n\n<references title=\"References\" references=\"[]\" />\n"}
{"level":"info","pid":18861,"time":1749038813,"caller":"/Users/wade/project/wuban/agentchat/main.go:255","message":"input--------{\"content\":\"What is the capital of UK?\",\"model\":\"gpt-3.5-turbo\",\"apiKey\":\"sk-1234567890abcdef\",\"from\":\"Alice\",\"from_id\":\"user123\",\"to\":\"Bob\",\"to_id\":\"user456\"}"}
{"level":"info","pid":18861,"time":1749038814,"caller":"/Users/wade/project/wuban/agentchat/main.go:281","message":"qaAsJson--------{\"ID\":29,\"CreatedAt\":\"2025-06-04T12:06:16.535774Z\",\"FromID\":\"user123\",\"From\":\"Alice\",\"Question\":\"What is the capital of UK?\",\"Answer\":\"I'm sorry, but the provided information does not contain the answer to your question about the capital of the UK.\\n\",\"Summary\":\"\",\"To\":\"Bob\",\"ToID\":\"user456\"}"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","time":1749038814,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:450","message":"Starting retrieve operation"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","collection":"chatRag1","documents":2,"time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/plugins/milvus/milvus.go:640","message":"Retrieve operation completed successfully"}
{"level":"info","pid":18861,"time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/main.go:307","message":"promptInput.Context: Paris is the capital of France?\nUSA is the largest importer of coffee?\n"}
{"level":"info","pid":18861,"method":"docStore.Retrieve","space_id":"","time":1749038815,"caller":"/Users/wade/project/wuban/agentchat/plugins/graphrag/graph.go:755","message":"Starting retrieve operation"}
{"level":"fatal","pid":18861,"time":1749038816,"caller":"/Users/wade/project/wuban/agentchat/main.go:362","message":"Server failed: failed to shutdown server: context canceled"}
openapi: 3.0.4
info:
title: Genkit Chat API
description: API for interacting with a chat endpoint powered by Genkit.
description: API for interacting with chat and indexing endpoints powered by Genkit.
version: 0.1.0
paths:
/index/document:
......@@ -44,12 +44,13 @@ paths:
content:
application/json:
schema:
type: object
properties:
id:
type: integer
description: The ID of the stored record
example: 1
$ref: '#/components/schemas/Response'
examples:
success:
value:
data: '{"id": 1}'
code: 200
msg: "Milvus index data stored successfully"
/index/graph:
post:
summary: Store GraphRAG index data
......@@ -90,12 +91,13 @@ paths:
content:
application/json:
schema:
type: object
properties:
id:
type: integer
description: The ID of the stored record
example: 1
$ref: '#/components/schemas/Response'
examples:
success:
value:
data: '{"id": 1}'
code: 200
msg: "GraphRAG index data stored successfully"
/chat:
post:
summary: Send a chat message
......@@ -145,11 +147,31 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
examples:
success:
value:
data: "The capital of the UK is London."
code: 200
msg: "Chat response generated successfully"
components:
schemas:
Response:
type: object
properties:
response:
data:
type: string
description: The response from the chat workflow
example: "The capital of the UK is London."
components:
schemas: {}
description: The response data, typically a JSON string or message
example: '{"id": 1}'
code:
type: integer
description: The response code (200 for success, 400 for invalid input, etc.)
example: 200
msg:
type: string
description: A message describing the result
example: "Milvus index data stored successfully"
required:
- data
- code
- msg
\ No newline at end of file
......@@ -36,8 +36,8 @@ func loggingInit() {
// // Configure log rotation with lumberjack
lumberjackLogger := &lumberjack.Logger{
Filename: "/var/log/agent_chat.log",
//Filename: "./tweet.log",
//Filename: "/var/log/agent_chat.log",
Filename: "agent_chat.log",
MaxSize: 1, // Max size in megabytes before log is rotated
MaxBackups: 3, // Max number of old log files to retain
MaxAge: 28, // Max number of days to retain old log files
......
......@@ -52,8 +52,32 @@ type GraphInput struct {
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type simpleQaPromptInput struct {
Query string `json:"query"`
Context string `json:"context"`
Graph string `json:"graph"`
Summary string `json:"summary"`
}
// const simpleQaPromptTemplate = `
// You're a helpful agent that answers the user's questions with a tone and style shaped by the specified personality.
// Here is the user's query: {{query}}
// Here is the context you should use: {{context}} from Milvus
// Graph context: {{graph}}
// Previous conversation summary: {{summary}}
// Personality to adopt: {{personality}}
// Please provide a response that aligns with the given personality while leveraging the provided context, graph, and conversation summary.
// `
const simpleQaPromptTemplate = `
You're a helpful agent that answers the user's questions with a tone and style shaped by the specified personality.
You're a helpful agent that answers the user's questions based on the provided context.
Here is the user's query: {{query}}
......@@ -63,11 +87,13 @@ Graph context: {{graph}}
Previous conversation summary: {{summary}}
Personality to adopt: {{personality}}
Please provide a response that aligns with the given personality while leveraging the provided context, graph, and conversation summary.
Instructions:
- If the query is related to a character's personality, adopt the tone and style specified in the Personality context, and generate a response using the Milvus and Graph contexts to inform the personality-driven content.
- For all other queries, provide a clear and accurate response using the Milvus and Graph contexts, without emphasizing the Personality context.
- Ensure responses leverage the Previous conversation summary when relevant.
`
func main() {
debug := flag.Bool("debug", false, "sets log level to debug")
......@@ -220,7 +246,10 @@ func main() {
inputAsJson, err := json.Marshal(input)
if err != nil {
return "", err
return Response{
Code: 500,
Msg: fmt.Sprintf("json.Marshal: %w", err),
}, nil
}
log.Info().Msgf("input--------%s", string(inputAsJson))
......@@ -234,13 +263,19 @@ func main() {
})
if err != nil {
return "", err
return Response{
Code: 500,
Msg: fmt.Sprintf("WriteAndGetLatestQA: %w", err),
}, nil
}
qaAsJson, err := json.Marshal(lastQa)
if err != nil {
return "", err
return Response{
Code: 500,
Msg: fmt.Sprintf("json.Marshal(lastQa): %w", err),
}, nil
}
log.Info().Msgf("qaAsJson--------%s", string(qaAsJson))
......@@ -299,7 +334,6 @@ func main() {
return Response{
Data: resp.Text(),
Code: 200,
Msg: fmt.Sprintf("Document indexed successfully, docname %s", resDocName),
}, nil
})
......@@ -329,12 +363,6 @@ func main() {
}
}
type simpleQaPromptInput struct {
Query string `json:"query"`
Context string `json:"context"`
Graph string `json:"graph"`
Summary string `json:"summary"`
}
type Response struct {
Data string `json:"data"`
......
......@@ -3,12 +3,12 @@ package deepseek
import (
"context"
"fmt"
"log"
"strings"
"sync"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/rs/zerolog/log"
deepseek "github.com/cohesion-org/deepseek-go"
)
......@@ -17,15 +17,6 @@ const provider = "deepseek"
var (
mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
// "mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
// "phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
// "nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
// "granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
// "granite3.3", "command-a", "command-r7b-arabic",
// }
roleMapping = map[ai.Role]string{
ai.RoleUser: deepseek.ChatMessageRoleUser,
ai.RoleModel: deepseek.ChatMessageRoleAssistant,
......@@ -37,7 +28,6 @@ var (
// DeepSeek holds configuration for the plugin.
type DeepSeek struct {
APIKey string // DeepSeek API key
//ServerAddress string
mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized.
......@@ -54,11 +44,16 @@ type ModelDefinition struct {
Type string
}
// // DefineModel defines a DeepSeek model in Genkit.
// DefineModel defines a DeepSeek model in Genkit.
func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
log.Info().
Str("method", "DeepSeek.DefineModel").
Str("model_name", model.Name).
Msg("Defining DeepSeek model")
d.mu.Lock()
defer d.mu.Unlock()
if !d.initted {
log.Error().Str("method", "DeepSeek.DefineModel").Msg("DeepSeek not initialized")
panic("deepseek.Init not called")
}
......@@ -78,27 +73,35 @@ func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai
}
meta := &ai.ModelInfo{
// Label: "DeepSeek - " + model.Name,
Label: model.Name,
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, apiKey: d.APIKey}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
modelDef := genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
log.Info().
Str("method", "DeepSeek.DefineModel").
Str("model_name", model.Name).
Msg("Model defined successfully")
return modelDef
}
// Init initializes the DeepSeek plugin.
func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
log.Info().Str("method", "DeepSeek.Init").Msg("Initializing DeepSeek plugin")
d.mu.Lock()
defer d.mu.Unlock()
if d.initted {
panic("deepseek.Init already called")
log.Error().Str("method", "DeepSeek.Init").Msg("Plugin already initialized")
return fmt.Errorf("deepseek.Init already called")
}
if d == nil || d.APIKey == "" {
log.Error().Str("method", "DeepSeek.Init").Msg("APIKey is required")
return fmt.Errorf("deepseek: need APIKey")
}
d.initted = true
log.Info().Str("method", "DeepSeek.Init").Msg("Initialization successful")
return nil
}
......@@ -110,14 +113,21 @@ type generator struct {
// generate implements the Genkit model generation interface.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
log.Info().
Str("method", "generator.generate").
Str("model_name", g.model.Name).
Int("messages", len(input.Messages)).
Msg("Starting model generation")
// stream := cb != nil
if len(input.Messages) == 0 {
log.Error().Str("method", "generator.generate").Msg("Prompt or messages required")
return nil, fmt.Errorf("prompt or messages required")
}
// Set up the Deepseek client
// Initialize DeepSeek client.
client := deepseek.NewClient(g.apiKey)
log.Debug().Str("method", "generator.generate").Msg("DeepSeek client initialized")
// Create a chat completion request
request := &deepseek.ChatCompletionRequest{
Model: g.model.Name,
......@@ -126,6 +136,10 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
for _, msg := range input.Messages {
role, ok := roleMapping[msg.Role]
if !ok {
log.Error().
Str("method", "generator.generate").
Str("role", string(msg.Role)).
Msg("Unsupported role")
return nil, fmt.Errorf("unsupported role: %s", msg.Role)
}
content := concatMessageParts(msg.Content)
......@@ -133,16 +147,27 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
Role: role,
Content: content,
})
log.Debug().
Str("method", "generator.generate").
Str("role", role).
Str("content", content).
Msg("Added message to request")
}
// Send the request and handle the response
response, err := client.CreateChatCompletion(ctx, request)
if err != nil {
log.Fatalf("error: %v", err)
log.Error().
Err(err).
Str("method", "generator.generate").
Msg("Failed to create chat completion")
return nil, fmt.Errorf("create chat completion: %w", err)
}
// Print the response
fmt.Println("Response:", response.Choices[0].Message.Content)
log.Debug().
Str("method", "generator.generate").
Int("choices", len(response.Choices)).
Msg("Received chat completion response")
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
......@@ -154,18 +179,32 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
}
for _, chunk := range response.Choices {
log.Debug().
Str("method", "generator.generate").
Int("index", chunk.Index).
Str("content", chunk.Message.Content).
Msg("Processing response chunk")
p := ai.Part{
Text: chunk.Message.Content,
Kind: ai.PartKind(chunk.Index),
}
finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
}
return finalResponse, nil // Return the final merged response
log.Info().
Str("method", "generator.generate").
Str("model_name", g.model.Name).
Int("content_parts", len(finalResponse.Message.Content)).
Msg("Model generation completed successfully")
return finalResponse, nil
}
// concatMessageParts concatenates message parts into a single string.
func concatMessageParts(parts []*ai.Part) string {
log.Debug().
Str("method", "concatMessageParts").
Int("parts", len(parts)).
Msg("Concatenating message parts")
var sb strings.Builder
for _, part := range parts {
if part.IsText() {
......@@ -173,9 +212,213 @@ func concatMessageParts(parts []*ai.Part) string {
}
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
}
return sb.String()
result := sb.String()
log.Debug().
Str("method", "concatMessageParts").
Str("result", result).
Msg("Concatenation complete")
return result
}
// package deepseek
// import (
// "context"
// "fmt"
// "log"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// deepseek "github.com/cohesion-org/deepseek-go"
// )
// const provider = "deepseek"
// var (
// mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// // toolSupportedModels = []string{
// // "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// // "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
// // "mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
// // "phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
// // "nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
// // "granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
// // "granite3.3", "command-a", "command-r7b-arabic",
// // }
// roleMapping = map[ai.Role]string{
// ai.RoleUser: deepseek.ChatMessageRoleUser,
// ai.RoleModel: deepseek.ChatMessageRoleAssistant,
// ai.RoleSystem: deepseek.ChatMessageRoleSystem,
// ai.RoleTool: deepseek.ChatMessageRoleTool,
// }
// )
// // DeepSeek holds configuration for the plugin.
// type DeepSeek struct {
// APIKey string // DeepSeek API key
// //ServerAddress string
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the provider name.
// func (d DeepSeek) Name() string {
// return provider
// }
// // ModelDefinition represents a model with its name and type.
// type ModelDefinition struct {
// Name string
// Type string
// }
// // // DefineModel defines a DeepSeek model in Genkit.
// func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// d.mu.Lock()
// defer d.mu.Unlock()
// if !d.initted {
// panic("deepseek.Init not called")
// }
// // Define model info, supporting multiturn and system role.
// mi := ai.ModelInfo{
// Label: model.Name,
// Supports: &ai.ModelSupports{
// Multiturn: true,
// SystemRole: true,
// Media: false, // DeepSeek API primarily supports text.
// Tools: false, // Tools not yet supported in this implementation.
// },
// Versions: []string{},
// }
// if info != nil {
// mi = *info
// }
// meta := &ai.ModelInfo{
// // Label: "DeepSeek - " + model.Name,
// Label: model.Name,
// Supports: mi.Supports,
// Versions: []string{},
// }
// gen := &generator{model: model, apiKey: d.APIKey}
// return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
// }
// // Init initializes the DeepSeek plugin.
// func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
// d.mu.Lock()
// defer d.mu.Unlock()
// if d.initted {
// panic("deepseek.Init already called")
// }
// if d == nil || d.APIKey == "" {
// return fmt.Errorf("deepseek: need APIKey")
// }
// d.initted = true
// return nil
// }
// // generator handles model generation.
// type generator struct {
// model ModelDefinition
// apiKey string
// }
// // generate implements the Genkit model generation interface.
// func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// // stream := cb != nil
// if len(input.Messages) == 0 {
// return nil, fmt.Errorf("prompt or messages required")
// }
// // Set up the Deepseek client
// // Initialize DeepSeek client.
// client := deepseek.NewClient(g.apiKey)
// // Create a chat completion request
// request := &deepseek.ChatCompletionRequest{
// Model: g.model.Name,
// }
// for _, msg := range input.Messages {
// role, ok := roleMapping[msg.Role]
// if !ok {
// return nil, fmt.Errorf("unsupported role: %s", msg.Role)
// }
// content := concatMessageParts(msg.Content)
// request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
// Role: role,
// Content: content,
// })
// }
// // Send the request and handle the response
// response, err := client.CreateChatCompletion(ctx, request)
// if err != nil {
// log.Fatalf("error: %v", err)
// }
// // Print the response
// fmt.Println("Response:", response.Choices[0].Message.Content)
// // Create a final response with the merged chunks
// finalResponse := &ai.ModelResponse{
// Request: input,
// FinishReason: ai.FinishReason("stop"),
// Message: &ai.Message{
// Role: ai.RoleModel,
// },
// }
// for _, chunk := range response.Choices {
// p := ai.Part{
// Text: chunk.Message.Content,
// Kind: ai.PartKind(chunk.Index),
// }
// finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
// }
// return finalResponse, nil // Return the final merged response
// }
// // concatMessageParts concatenates message parts into a single string.
// func concatMessageParts(parts []*ai.Part) string {
// var sb strings.Builder
// for _, part := range parts {
// if part.IsText() {
// sb.WriteString(part.Text)
// }
// // Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
// }
// return sb.String()
// }
/*
// Choice represents a completion choice generated by the model.
......
......@@ -32,7 +32,7 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/rs/zerolog/log"
"github.com/wade-liwei/agentchat/util"
)
......@@ -86,6 +86,11 @@ type SyncBatchRequest struct {
// NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client {
log.Info().
Str("method", "NewClient").
Str("ip", ip).
Int("port", port).
Msg("Creating new GraphRAG client")
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
......@@ -93,14 +98,21 @@ func NewClient(ip string, port int) *Client {
// AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
log.Info().
Str("method", "Client.AddSpace").
Str("name", req.Name).
Str("owner", req.Owner).
Msg("Adding knowledge space")
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to marshal request")
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to create request")
return nil, fmt.Errorf("failed to create request: %w", err)
}
......@@ -110,22 +122,34 @@ func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddSpace").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.AddSpace").
Int("status_code", resp.StatusCode).
Msg("Space addition request completed")
return resp, nil
}
// AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
log.Info().
Str("method", "Client.AddDocument").
Str("space_id", spaceID).
Str("doc_name", req.DocName).
Msg("Adding document")
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to marshal request")
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to create request")
return nil, fmt.Errorf("failed to create request: %w", err)
}
......@@ -135,9 +159,15 @@ func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Respons
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.AddDocument").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.AddDocument").
Str("space_id", spaceID).
Int("status_code", resp.StatusCode).
Msg("Document addition request completed")
return resp, nil
}
......@@ -148,17 +178,24 @@ type SyncDocumentsRequest struct {
// SyncDocuments sends a POST request to sync documents for the given spaceID.
func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, err error) {
log.Info().
Str("method", "Client.SyncDocuments").
Str("space_id", spaceID).
Strs("doc_ids", docIDs).
Msg("Syncing documents")
url := fmt.Sprintf("%s/knowledge/%s/document/sync", c.BaseURL, spaceID)
reqBody := SyncDocumentsRequest{
DocIDs: docIDs,
}
body, err := json.Marshal(reqBody)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to marshal request")
return false, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to create request")
return false, fmt.Errorf("failed to create request: %w", err)
}
......@@ -168,32 +205,50 @@ func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, e
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to send request")
return false, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncDocuments").Msg("Failed to read response body")
return false, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Error().
Str("method", "Client.SyncDocuments").
Int("status_code", resp.StatusCode).
Str("response_body", string(respBody)).
Msg("Sync request failed")
return false, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
}
return success, nil
log.Info().
Str("method", "Client.SyncDocuments").
Str("space_id", spaceID).
Msg("Documents synced successfully")
return true, nil
}
// SyncBatchDocument 同步批量处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
log.Info().
Str("method", "Client.SyncBatchDocument").
Str("space_id", spaceID).
Int("requests", len(req)).
Msg("Syncing batch documents")
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to marshal request")
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to create request")
return nil, fmt.Errorf("failed to create request: %w", err)
}
......@@ -203,9 +258,15 @@ func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*htt
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "Client.SyncBatchDocument").Msg("Failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
log.Info().
Str("method", "Client.SyncBatchDocument").
Str("space_id", spaceID).
Int("status_code", resp.StatusCode).
Msg("Batch document sync request completed")
return resp, nil
}
......@@ -235,6 +296,7 @@ func (k *GraphKnowledge) Name() string {
// Init initializes the GraphKnowledge plugin.
func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
log.Info().Str("method", "GraphKnowledge.Init").Msg("Initializing GraphKnowledge plugin")
if k == nil {
k = &GraphKnowledge{}
}
......@@ -243,7 +305,10 @@ func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error)
defer k.mu.Unlock()
defer func() {
if err != nil {
log.Error().Err(err).Str("method", "GraphKnowledge.Init").Msg("Initialization failed")
err = fmt.Errorf("graphrag.Init: %w", err)
} else {
log.Info().Str("method", "GraphKnowledge.Init").Msg("Initialization successful")
}
}()
......@@ -269,31 +334,46 @@ func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error)
func parseAddr(addr string) (string, int) {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
log.Warn().
Str("method", "parseAddr").
Str("addr", addr).
Msg("Invalid address format, using default")
return "54.92.111.204", 5670
}
port, err := strconv.Atoi(parts[1])
if err != nil {
log.Error().
Err(err).
Str("method", "parseAddr").
Str("port", parts[1]).
Msg("Failed to parse port, using default")
return "54.92.111.204", 5670
}
port, _ := strconv.Atoi(parts[1])
return parts[0], port
}
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
log.Info().Str("method", "DefineIndexerAndRetriever").Msg("Defining indexer and retriever")
spaceID := ""
modelName := "Qwen/Qwen2.5-Coder-32B-Instruct"
k := genkit.LookupPlugin(g, provider)
if k == nil {
log.Error().Str("method", "DefineIndexerAndRetriever").Msg("GraphRAG plugin not found")
return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
}
knowledge := k.(*GraphKnowledge)
ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
if err != nil {
log.Error().Err(err).Str("method", "DefineIndexerAndRetriever").Msg("Failed to create doc store")
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
log.Info().Str("method", "DefineIndexerAndRetriever").Msg("Indexer and retriever defined successfully")
return indexer, retriever, nil
}
......@@ -306,10 +386,17 @@ type docStore struct {
// newDocStore creates a docStore.
func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
log.Info().
Str("method", "GraphKnowledge.newDocStore").
Str("space_id", spaceID).
Str("model_name", modelName).
Msg("Creating new doc store")
if k.client == nil {
log.Error().Str("method", "GraphKnowledge.newDocStore").Msg("GraphRAG client not initialized")
return nil, errors.New("graphrag.Init not called")
}
log.Info().Str("method", "GraphKnowledge.newDocStore").Msg("Doc store created successfully")
return &docStore{
client: k.client,
spaceID: spaceID,
......@@ -319,16 +406,42 @@ func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName str
// Indexer returns the indexer for a space.
func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
return genkit.LookupIndexer(g, provider, spaceID)
log.Info().
Str("method", "Indexer").
Str("space_id", spaceID).
Msg("Looking up indexer")
indexer := genkit.LookupIndexer(g, provider, spaceID)
if indexer == nil {
log.Warn().
Str("method", "Indexer").
Str("space_id", spaceID).
Msg("Indexer not found")
}
return indexer
}
// Retriever returns the retriever for a space.
func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
return genkit.LookupRetriever(g, provider, spaceID)
log.Info().
Str("method", "Retriever").
Str("space_id", spaceID).
Msg("Looking up retriever")
retriever := genkit.LookupRetriever(g, provider, spaceID)
if retriever == nil {
log.Warn().
Str("method", "Retriever").
Str("space_id", spaceID).
Msg("Retriever not found")
}
return retriever
}
// generateRandomDocName generates a random alphanumeric string of the specified length.
func GenerateRandomDocName(length int) (string, error) {
log.Debug().
Str("method", "GenerateRandomDocName").
Int("length", length).
Msg("Generating random document name")
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
var result strings.Builder
result.Grow(length)
......@@ -336,28 +449,51 @@ func GenerateRandomDocName(length int) (string, error) {
for i := 0; i < length; i++ {
idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
log.Error().
Err(err).
Str("method", "GenerateRandomDocName").
Msg("Failed to generate random index")
return "", fmt.Errorf("failed to generate random index: %w", err)
}
result.WriteByte(charset[idx.Int64()])
}
return result.String(), nil
docName := result.String()
log.Debug().
Str("method", "GenerateRandomDocName").
Str("doc_name", docName).
Msg("Generated document name")
return docName, nil
}
// ParseJSONResponse parses a JSON byte slice and extracts the success boolean and data fields as a string.
func ParseJSONResponse(jsonBytes []byte) (success bool, data string, err error) {
log.Debug().
Str("method", "ParseJSONResponse").
Str("json", string(jsonBytes)).
Msg("Parsing JSON response")
// Define struct to capture only the needed fields
type jsonResponse struct {
Success bool `json:"success"`
Data int `json:"data"` // Use string to capture JSON string data
Data int `json:"data"`
}
var resp jsonResponse
if err := json.Unmarshal(jsonBytes, &resp); err != nil {
log.Error().
Err(err).
Str("method", "ParseJSONResponse").
Msg("Failed to unmarshal JSON")
return false, "", fmt.Errorf("failed to unmarshal JSON: %w", err)
}
return resp.Success, fmt.Sprintf("%d", resp.Data), nil
dataStr := fmt.Sprintf("%d", resp.Data)
log.Debug().
Str("method", "ParseJSONResponse").
Bool("success", resp.Success).
Str("data", dataStr).
Msg("Parsed JSON response")
return resp.Success, dataStr, nil
}
type IndexReqOption struct {
......@@ -369,21 +505,37 @@ const DocNameKey = "doc_name"
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
log.Info().
Str("method", "docStore.Index").
Str("space_id", ds.spaceID).
Int("documents", len(req.Documents)).
Msg("Starting index operation")
if len(req.Documents) == 0 {
log.Debug().
Str("method", "docStore.Index").
Str("space_id", ds.spaceID).
Msg("No documents to index")
return nil
}
// Type-assert req.Options to IndexReqOption
opt, ok := req.Options.(*IndexReqOption)
if !ok {
log.Error().
Str("method", "docStore.Index").
Str("options_type", fmt.Sprintf("%T", req.Options)).
Msg("Invalid options type")
return fmt.Errorf("invalid options type: got %T, want *IndexReqOption", req.Options)
}
// Validate required fields
if opt.UserId == "" {
log.Error().Str("method", "docStore.Index").Msg("UserId is required")
return fmt.Errorf("UserId is required in IndexReqOption")
}
if opt.UserName == "" {
log.Error().Str("method", "docStore.Index").Msg("UserName is required")
return fmt.Errorf("UserName is required in IndexReqOption")
}
......@@ -397,33 +549,47 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
}
resp, err := ds.client.AddSpace(spaceReq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Index").Msg("Failed to add space")
return fmt.Errorf("add space: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Error().
Str("method", "docStore.Index").
Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
Msg("Add space failed")
return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
}
fmt.Println("space ok")
log.Info().Str("method", "docStore.Index").Str("space_id", opt.UserId).Msg("Space created successfully")
spaceId := opt.UserId
// Index each document
for i, doc := range req.Documents {
// Use DocName from options, fall back to random name if empty
docName := ""
if v, ok := doc.Metadata[DocNameKey]; ok {
if str, isString := v.(string); isString {
docName = str
} else {
return fmt.Errorf("must provide doc_name str value in metadata")
}
} else {
// Use DocName from metadata
docName, ok := doc.Metadata[DocNameKey].(string)
if !ok {
log.Error().
Str("method", "docStore.Index").
Int("index", i).
Msg("Missing doc_name in metadata")
return fmt.Errorf("must provide doc_name key in metadata")
}
if docName == "" {
log.Error().
Str("method", "docStore.Index").
Int("index", i).
Msg("doc_name is empty")
return fmt.Errorf("must provide non-empty doc_name str value in metadata")
}
fmt.Println("docName: ", docName)
log.Debug().
Str("method", "docStore.Index").
Int("index", i).
Str("doc_name", docName).
Msg("Processing document")
// Add document
var sb strings.Builder
......@@ -433,7 +599,12 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
}
}
text := sb.String()
fmt.Println("text: ", text)
log.Debug().
Str("method", "docStore.Index").
Int("index", i).
Str("text", text).
Msg("Extracted document text")
docReq := DocumentRequest{
DocName: docName,
Source: "api",
......@@ -444,313 +615,81 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
}
resp, err := ds.client.AddDocument(spaceId, docReq)
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to add document")
return fmt.Errorf("add document %d: %w", i+1, err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to read add document response")
return fmt.Errorf("read add document response %d: %w", i+1, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Error().
Str("method", "docStore.Index").
Int("index", i+1).
Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
Msg("Add document failed")
return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
}
// Parse AddDocument response
ok, idx, err := ParseJSONResponse(body)
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to parse add document")
return fmt.Errorf("parse add document response %d: %w", i+1, err)
}
if !ok {
log.Error().
Str("method", "docStore.Index").
Int("index", i+1).
Str("data", idx).
Msg("Add document response indicated failure")
return fmt.Errorf("add document %d failed: response success=false, data=%s", i+1, idx)
}
fmt.Println("document ok", string(body), idx)
log.Info().
Str("method", "docStore.Index").
Int("index", i+1).
Str("doc_id", idx).
Msg("Document added successfully")
// Sync document
_, err = ds.client.SyncDocuments(spaceId, []string{idx})
if err != nil {
log.Error().
Err(err).
Str("method", "docStore.Index").
Int("index", i+1).
Msg("Failed to sync document")
return fmt.Errorf("sync document %d: %w", i+1, err)
}
}
log.Info().
Str("method", "docStore.Index").
Str("space_id", ds.spaceID).
Int("documents", len(req.Documents)).
Msg("Index operation completed successfully")
return nil
}
// type IndexReqOption struct{
// UserId string
// UserName string
// DocName string
// }
// // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 {
// return nil
// }
// req.Options
// userid := ""
// usernmae := ""
// for _, doc := range req.Documents {
// if v, ok := doc.Metadata["user_id"]; ok {
// if str, isString := v.(string); isString {
// userid = str
// }
// }
// if v, ok := doc.Metadata["username"]; ok {
// if str, isString := v.(string); isString {
// usernmae = str
// }
// }
// }
// // Create knowledge space.
// spaceReq := SpaceRequest{
// Name: userid,
// VectorType: "KnowledgeGraph",
// DomainType: "Normal",
// Desc: usernmae,
// Owner: userid,
// }
// resp, err := ds.client.AddSpace(spaceReq)
// if err != nil {
// return fmt.Errorf("add space: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// body, _ := io.ReadAll(resp.Body)
// return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
// }
// fmt.Println("space ok")
// spaceId := userid
// // Index each document.
// for i, doc := range req.Documents {
// docName := ""
// if v, ok := doc.Metadata["doc_name"]; ok {
// if str, isString := v.(string); isString {
// docName = str
// } else {
// // Generate random docName.
// var err error
// docName, err = generateRandomDocName(8)
// if err != nil {
// return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// }
// }
// } else {
// // Generate random docName.
// var err error
// docName, err = generateRandomDocName(8)
// if err != nil {
// return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// }
// }
// fmt.Println("docName: ", docName)
// // Add document.
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// fmt.Println("text: ",text)
// docReq := DocumentRequest{
// DocName: docName,
// Source: "api",
// DocType: "TEXT",
// Content: text,
// Labels: "",
// // Questions: []string{},
// Metadata: doc.Metadata,
// }
// resp, err := ds.client.AddDocument(spaceId, docReq)
// if err != nil {
// return fmt.Errorf("add document %d: %w", i+1, err)
// }
// body, _ := io.ReadAll(resp.Body)
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
// }
// ok, idx, err := ParseJSONResponse(body)
// if err != nil {
// return fmt.Errorf("ParseJSONResponse %d: %w", i+1, err)
// }
// if !ok{
// return fmt.Errorf("ParseJSONResponse body %d: %w", i+1, err)
// }
// fmt.Println("document ok",string(body),idx)
// ok ,err =ds.client.SyncDocuments(spaceId,[]string{idx})
// if err != nil{
// return err
// }
// }
// return nil
// }
// // RetrieverOptions for Knowledge retrieval.
// type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve.
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// }
// // Assuming ai.Part has a Text() method or Text field to get the string content
// func partsToString(parts []*ai.Part) string {
// var texts []string
// for _, part := range parts {
// // Adjust this based on the actual ai.Part structure
// // If ai.Part has a Text() method:
// texts = append(texts, part.Text)
// // OR if ai.Part has a Text field:
// // texts = append(texts, part.Text)
// }
// return strings.Join(texts, " ")
// }
// // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// // count := 3
// // metricTypeStr := "L2"
// // if req.Options != nil {
// // ropt, ok := req.Options.(*RetrieverOptions)
// // if !ok {
// // return nil, fmt.Errorf("graphrag.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// // }
// // if ropt.Count > 0 {
// // count = ropt.Count
// // }
// // if ropt.MetricType != "" {
// // metricTypeStr = ropt.MetricType
// // }
// // }
// queryContent := partsToString(req.Query.Content)
// // Format query for retrieval.
// queryText := fmt.Sprintf("Search for: %s", queryContent)
// username := "Alice" // Default, override if metadata available.
// if req.Query.Metadata != nil {
// if uname, ok := req.Query.Metadata["username"].(string); ok {
// username = uname
// }
// }
// // Prepare request for chat completions endpoint.
// url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL)
// chatReq := struct {
// ConvUID string `json:"conv_uid"`
// UserInput string `json:"user_input"`
// UserName string `json:"user_name"`
// ChatMode string `json:"chat_mode"`
// AppCode string `json:"app_code"`
// Temperature float32 `json:"temperature"`
// MaxNewTokens int `json:"max_new_tokens"`
// SelectParam string `json:"select_param"`
// ModelName string `json:"model_name"`
// Incremental bool `json:"incremental"`
// SysCode string `json:"sys_code"`
// PromptCode string `json:"prompt_code"`
// ExtInfo map[string]interface{} `json:"ext_info"`
// }{
// ConvUID: "",
// UserInput: queryText,
// UserName: username,
// ChatMode: "",
// AppCode: "",
// Temperature: 0.5,
// MaxNewTokens: 4000,
// SelectParam: "",
// ModelName: ds.modelName,
// Incremental: false,
// SysCode: "",
// PromptCode: "",
// ExtInfo: map[string]interface{}{
// "space_id": ds.spaceID,
// //"k": count,
// },
// }
// body, err := json.Marshal(chatReq)
// if err != nil {
// return nil, fmt.Errorf("marshal chat request: %w", err)
// }
// httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("create chat request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("send chat request: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// body, _ := io.ReadAll(resp.Body)
// return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
// }
// // Parse response
// var chatResp struct {
// Success bool `json:"success"`
// Data struct {
// Answer []struct {
// Content string `json:"content"`
// DocID int `json:"doc_id"`
// Score float64 `json:"score"`
// Metadata map[string]interface{} `json:"metadata_map"`
// } `json:"answer"`
// } `json:"data"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
// return nil, fmt.Errorf("decode chat response: %w", err)
// }
// var docs []*ai.Document
// for _, doc := range chatResp.Data.Answer {
// metadata := doc.Metadata
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// // Ensure metadata includes user_id and username.
// if _, ok := metadata["user_id"]; !ok {
// metadata["user_id"] = "user123"
// }
// if _, ok := metadata["username"]; !ok {
// metadata["username"] = username
// }
// aiDoc := ai.DocumentFromText(doc.Content, metadata)
// docs = append(docs, aiDoc)
// }
// return &ai.RetrieverResponse{
// Documents: docs,
// }, nil
// }
// ChatRequest 定义请求结构体,匹配单元测试的 curl 请求
// ChatRequest defines the request structure for chat completions.
type ChatRequest struct {
Model string `json:"model"`
Messages string `json:"messages"`
......@@ -758,7 +697,7 @@ type ChatRequest struct {
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
N int `json:"n"`
MaxTokens int `json:"max_tokens"`
MaxTokens int64 `json:"max_tokens"`
Stream bool `json:"stream"`
RepetitionPenalty float64 `json:"repetition_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
......@@ -768,7 +707,7 @@ type ChatRequest struct {
EnableVis bool `json:"enable_vis"`
}
// ChatResponse 定义响应结构体,匹配单元测试的 API 响应
// ChatResponse defines the response structure from the API.
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
......@@ -790,42 +729,64 @@ type ChatResponse struct {
} `json:"usage"`
}
// Assuming ai.Part has a Text() method or Text field to get the string content
// Assuming ai.Part has a Text() method or Text field to get string content.
func partsToString(parts []*ai.Part) string {
log.Debug().
Str("method", "partsToString").
Int("parts", len(parts)).
Msg("Converting parts to string")
var texts []string
for _, part := range parts {
// Adjust this based on the actual ai.Part structure
// If ai.Part has a Text() method:
texts = append(texts, part.Text)
// OR if ai.Part has a Text field:
// texts = append(texts, part.Text)
}
return strings.Join(texts, " ")
result := strings.Join(texts, " ")
log.Debug().
Str("method", "partsToString").
Str("result", result).
Msg("Conversion complete")
return result
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
log.Info().
Str("method", "docStore.Retrieve").
Str("space_id", ds.spaceID).
Msg("Starting retrieve operation")
// Format query for retrieval.
queryContent := partsToString(req.Query.Content)
queryText := fmt.Sprintf("Search for: %s", queryContent)
log.Debug().
Str("method", "docStore.Retrieve").
Str("query", queryText).
Msg("Formatted query")
if req.Query.Metadata == nil {
// If ok, we don't use the User struct since the requirement is to error on non-nil
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
log.Error().
Str("method", "docStore.Retrieve").
Str("metadata_type", fmt.Sprintf("%T", req.Query.Metadata)).
Msg("Query metadata is nil")
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
for k, v := range req.Query.Metadata {
fmt.Println("k", k, "v", v)
log.Debug().
Str("method", "docStore.Retrieve").
Str("key", k).
Interface("value", v).
Msg("Metadata entry")
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Msg("Missing username in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Msg("Missing user_id in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
......@@ -848,13 +809,21 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
EnableVis: true,
}
log.Debug().
Str("method", "docStore.Retrieve").
Str("url", url).
Interface("chat_request", chatReq).
Msg("Preparing chat completion request")
body, err := json.Marshal(chatReq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to marshal chat request")
return nil, fmt.Errorf("marshal chat request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to create chat request")
return nil, fmt.Errorf("create chat request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
......@@ -863,22 +832,42 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to send chat request")
return nil, fmt.Errorf("send chat request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Error().
Str("method", "docStore.Retrieve").
Int("status_code", resp.StatusCode).
Str("response_body", string(body)).
Msg("Chat completion failed")
return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var chatResp ChatResponse
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(body))
respBody, err := io.ReadAll(resp.Body)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Msg("Failed to read response body")
return nil, fmt.Errorf("read chat response body: %w", err)
}
if err := json.Unmarshal(respBody, &chatResp); err != nil {
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("raw_response", string(respBody)).
Msg("Failed to decode chat response")
return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(respBody))
}
log.Debug().
Str("method", "docStore.Retrieve").
Interface("chat_response", chatResp).
Msg("Parsed chat response")
// Convert response to ai.Document
var docs []*ai.Document
if len(chatResp.Choices) > 0 {
......@@ -889,9 +878,935 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
}
aiDoc := ai.DocumentFromText(content, metadata)
docs = append(docs, aiDoc)
log.Debug().
Str("method", "docStore.Retrieve").
Str("content", content).
Interface("metadata", metadata).
Msg("Created document from response")
}
log.Info().
Str("method", "docStore.Retrieve").
Str("space_id", ds.spaceID).
Int("documents", len(docs)).
Msg("Retrieve operation completed successfully")
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package graphrag
// import (
// "bytes"
// "context"
// "crypto/rand"
// "encoding/json"
// "errors"
// "fmt"
// "io"
// "math/big"
// "net/http"
// "strconv"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/wade-liwei/agentchat/util"
// )
// // Client 知识库客户端
// type Client struct {
// BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
// }
// // SpaceRequest 创建空间的请求结构体
// type SpaceRequest struct {
// ID int `json:"id"`
// Name string `json:"name"`
// VectorType string `json:"vector_type"`
// DomainType string `json:"domain_type"`
// Desc string `json:"desc"`
// Owner string `json:"owner"`
// SpaceID int `json:"space_id"`
// }
// // DocumentRequest 添加文档的请求结构体
// type DocumentRequest struct {
// DocName string `json:"doc_name"`
// DocID int `json:"doc_id"`
// DocType string `json:"doc_type"`
// DocToken string `json:"doc_token"`
// Content string `json:"content"`
// Source string `json:"source"`
// Labels string `json:"labels"`
// Questions []string `json:"questions"`
// Metadata map[string]interface{} `json:"metadata"`
// }
// // ChunkParameters 分片参数
// type ChunkParameters struct {
// ChunkStrategy string `json:"chunk_strategy"`
// TextSplitter string `json:"text_splitter"`
// SplitterType string `json:"splitter_type"`
// ChunkSize int `json:"chunk_size"`
// ChunkOverlap int `json:"chunk_overlap"`
// Separator string `json:"separator"`
// EnableMerge bool `json:"enable_merge"`
// }
// // SyncBatchRequest 同步批量处理的请求结构体
// type SyncBatchRequest struct {
// DocID int `json:"doc_id"`
// SpaceID string `json:"space_id"`
// ModelName string `json:"model_name"`
// ChunkParameters ChunkParameters `json:"chunk_parameters"`
// }
// // NewClient 创建新的客户端实例
// func NewClient(ip string, port int) *Client {
// return &Client{
// BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
// }
// }
// // AddSpace 创建知识空间
// func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // AddDocument 添加文档
// func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // SyncDocumentsRequest defines the request body for the sync documents endpoint.
// type SyncDocumentsRequest struct {
// DocIDs []string `json:"doc_ids"`
// }
// // SyncDocuments sends a POST request to sync documents for the given spaceID.
// func (c *Client) SyncDocuments(spaceID string, docIDs []string) (success bool, err error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/sync", c.BaseURL, spaceID)
// reqBody := SyncDocumentsRequest{
// DocIDs: docIDs,
// }
// body, err := json.Marshal(reqBody)
// if err != nil {
// return false, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return false, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return false, fmt.Errorf("failed to send request: %w", err)
// }
// defer resp.Body.Close()
// respBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return false, fmt.Errorf("failed to read response body: %w", err)
// }
// if resp.StatusCode != http.StatusOK {
// return false, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
// }
// return success, nil
// }
// // SyncBatchDocument 同步批量处理文档
// func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
// url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
// body, err := json.Marshal(req)
// if err != nil {
// return nil, fmt.Errorf("failed to marshal request: %w", err)
// }
// httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("failed to send request: %w", err)
// }
// return resp, nil
// }
// // The provider used in the registry.
// const provider = "graphrag"
// // Field names for schema.
// const (
// idField = "id"
// textField = "text"
// metadataField = "metadata"
// )
// // GraphKnowledge holds configuration for the plugin.
// type GraphKnowledge struct {
// Addr string // Knowledge server address (host:port, e.g., "54.92.111.204:5670").
// client *Client // Knowledge client.
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the plugin name.
// func (k *GraphKnowledge) Name() string {
// return provider
// }
// // Init initializes the GraphKnowledge plugin.
// func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
// if k == nil {
// k = &GraphKnowledge{}
// }
// k.mu.Lock()
// defer k.mu.Unlock()
// defer func() {
// if err != nil {
// err = fmt.Errorf("graphrag.Init: %w", err)
// }
// }()
// if k.initted {
// return errors.New("plugin already initialized")
// }
// // Load configuration.
// addr := k.Addr
// if addr == "" {
// addr = "54.92.111.204:5670" // Default address.
// }
// // Initialize Knowledge client.
// host, port := parseAddr(addr)
// client := NewClient(host, port)
// k.client = client
// k.initted = true
// return nil
// }
// // parseAddr splits host:port into host and port.
// func parseAddr(addr string) (string, int) {
// parts := strings.Split(addr, ":")
// if len(parts) != 2 {
// return "54.92.111.204", 5670
// }
// port, _ := strconv.Atoi(parts[1])
// return parts[0], port
// }
// // DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space.
// func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
// spaceID := ""
// modelName := "Qwen/Qwen2.5-Coder-32B-Instruct"
// k := genkit.LookupPlugin(g, provider)
// if k == nil {
// return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
// }
// knowledge := k.(*GraphKnowledge)
// ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
// if err != nil {
// return nil, nil, err
// }
// indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
// retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
// return indexer, retriever, nil
// }
// // docStore defines an Indexer and a Retriever.
// type docStore struct {
// client *Client
// spaceID string
// modelName string
// }
// // newDocStore creates a docStore.
// func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
// if k.client == nil {
// return nil, errors.New("graphrag.Init not called")
// }
// return &docStore{
// client: k.client,
// spaceID: spaceID,
// modelName: modelName,
// }, nil
// }
// // Indexer returns the indexer for a space.
// func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
// return genkit.LookupIndexer(g, provider, spaceID)
// }
// // Retriever returns the retriever for a space.
// func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
// return genkit.LookupRetriever(g, provider, spaceID)
// }
// // generateRandomDocName generates a random alphanumeric string of the specified length.
// func GenerateRandomDocName(length int) (string, error) {
// const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
// var result strings.Builder
// result.Grow(length)
// for i := 0; i < length; i++ {
// idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
// if err != nil {
// return "", fmt.Errorf("failed to generate random index: %w", err)
// }
// result.WriteByte(charset[idx.Int64()])
// }
// return result.String(), nil
// }
// // ParseJSONResponse parses a JSON byte slice and extracts the success boolean and data fields as a string.
// func ParseJSONResponse(jsonBytes []byte) (success bool, data string, err error) {
// // Define struct to capture only the needed fields
// type jsonResponse struct {
// Success bool `json:"success"`
// Data int `json:"data"` // Use string to capture JSON string data
// }
// var resp jsonResponse
// if err := json.Unmarshal(jsonBytes, &resp); err != nil {
// return false, "", fmt.Errorf("failed to unmarshal JSON: %w", err)
// }
// return resp.Success, fmt.Sprintf("%d", resp.Data), nil
// }
// type IndexReqOption struct {
// UserId string
// UserName string
// }
// const DocNameKey = "doc_name"
// // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 {
// return nil
// }
// // Type-assert req.Options to IndexReqOption
// opt, ok := req.Options.(*IndexReqOption)
// if !ok {
// return fmt.Errorf("invalid options type: got %T, want *IndexReqOption", req.Options)
// }
// // Validate required fields
// if opt.UserId == "" {
// return fmt.Errorf("UserId is required in IndexReqOption")
// }
// if opt.UserName == "" {
// return fmt.Errorf("UserName is required in IndexReqOption")
// }
// // Create knowledge space
// spaceReq := SpaceRequest{
// Name: opt.UserId,
// VectorType: "KnowledgeGraph",
// DomainType: "Normal",
// Desc: opt.UserName,
// Owner: opt.UserId,
// }
// resp, err := ds.client.AddSpace(spaceReq)
// if err != nil {
// return fmt.Errorf("add space: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// body, _ := io.ReadAll(resp.Body)
// return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
// }
// fmt.Println("space ok")
// spaceId := opt.UserId
// // Index each document
// for i, doc := range req.Documents {
// // Use DocName from options, fall back to random name if empty
// docName := ""
// if v, ok := doc.Metadata[DocNameKey]; ok {
// if str, isString := v.(string); isString {
// docName = str
// } else {
// return fmt.Errorf("must provide doc_name str value in metadata")
// }
// } else {
// return fmt.Errorf("must provide doc_name key in metadata")
// }
// fmt.Println("docName: ", docName)
// // Add document
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// fmt.Println("text: ", text)
// docReq := DocumentRequest{
// DocName: docName,
// Source: "api",
// DocType: "TEXT",
// Content: text,
// Labels: "",
// Metadata: doc.Metadata,
// }
// resp, err := ds.client.AddDocument(spaceId, docReq)
// if err != nil {
// return fmt.Errorf("add document %d: %w", i+1, err)
// }
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// resp.Body.Close()
// return fmt.Errorf("read add document response %d: %w", i+1, err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
// }
// // Parse AddDocument response
// ok, idx, err := ParseJSONResponse(body)
// if err != nil {
// return fmt.Errorf("parse add document response %d: %w", i+1, err)
// }
// if !ok {
// return fmt.Errorf("add document %d failed: response success=false, data=%s", i+1, idx)
// }
// fmt.Println("document ok", string(body), idx)
// // Sync document
// _, err = ds.client.SyncDocuments(spaceId, []string{idx})
// if err != nil {
// return fmt.Errorf("sync document %d: %w", i+1, err)
// }
// }
// return nil
// }
// // type IndexReqOption struct{
// // UserId string
// // UserName string
// // DocName string
// // }
// // // Index implements the Indexer.Index method.
// // func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// // if len(req.Documents) == 0 {
// // return nil
// // }
// // req.Options
// // userid := ""
// // usernmae := ""
// // for _, doc := range req.Documents {
// // if v, ok := doc.Metadata["user_id"]; ok {
// // if str, isString := v.(string); isString {
// // userid = str
// // }
// // }
// // if v, ok := doc.Metadata["username"]; ok {
// // if str, isString := v.(string); isString {
// // usernmae = str
// // }
// // }
// // }
// // // Create knowledge space.
// // spaceReq := SpaceRequest{
// // Name: userid,
// // VectorType: "KnowledgeGraph",
// // DomainType: "Normal",
// // Desc: usernmae,
// // Owner: userid,
// // }
// // resp, err := ds.client.AddSpace(spaceReq)
// // if err != nil {
// // return fmt.Errorf("add space: %w", err)
// // }
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // body, _ := io.ReadAll(resp.Body)
// // return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
// // }
// // fmt.Println("space ok")
// // spaceId := userid
// // // Index each document.
// // for i, doc := range req.Documents {
// // docName := ""
// // if v, ok := doc.Metadata["doc_name"]; ok {
// // if str, isString := v.(string); isString {
// // docName = str
// // } else {
// // // Generate random docName.
// // var err error
// // docName, err = generateRandomDocName(8)
// // if err != nil {
// // return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// // }
// // }
// // } else {
// // // Generate random docName.
// // var err error
// // docName, err = generateRandomDocName(8)
// // if err != nil {
// // return fmt.Errorf("generate random docName for document %d: %w", i+1, err)
// // }
// // }
// // fmt.Println("docName: ", docName)
// // // Add document.
// // var sb strings.Builder
// // for _, p := range doc.Content {
// // if p.IsText() {
// // sb.WriteString(p.Text)
// // }
// // }
// // text := sb.String()
// // fmt.Println("text: ",text)
// // docReq := DocumentRequest{
// // DocName: docName,
// // Source: "api",
// // DocType: "TEXT",
// // Content: text,
// // Labels: "",
// // // Questions: []string{},
// // Metadata: doc.Metadata,
// // }
// // resp, err := ds.client.AddDocument(spaceId, docReq)
// // if err != nil {
// // return fmt.Errorf("add document %d: %w", i+1, err)
// // }
// // body, _ := io.ReadAll(resp.Body)
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
// // }
// // ok, idx, err := ParseJSONResponse(body)
// // if err != nil {
// // return fmt.Errorf("ParseJSONResponse %d: %w", i+1, err)
// // }
// // if !ok{
// // return fmt.Errorf("ParseJSONResponse body %d: %w", i+1, err)
// // }
// // fmt.Println("document ok",string(body),idx)
// // ok ,err =ds.client.SyncDocuments(spaceId,[]string{idx})
// // if err != nil{
// // return err
// // }
// // }
// // return nil
// // }
// // // RetrieverOptions for Knowledge retrieval.
// // type RetrieverOptions struct {
// // Count int `json:"count,omitempty"` // Max documents to retrieve.
// // MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// // }
// // // Assuming ai.Part has a Text() method or Text field to get the string content
// // func partsToString(parts []*ai.Part) string {
// // var texts []string
// // for _, part := range parts {
// // // Adjust this based on the actual ai.Part structure
// // // If ai.Part has a Text() method:
// // texts = append(texts, part.Text)
// // // OR if ai.Part has a Text field:
// // // texts = append(texts, part.Text)
// // }
// // return strings.Join(texts, " ")
// // }
// // // Retrieve implements the Retriever.Retrieve method.
// // func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// // // count := 3
// // // metricTypeStr := "L2"
// // // if req.Options != nil {
// // // ropt, ok := req.Options.(*RetrieverOptions)
// // // if !ok {
// // // return nil, fmt.Errorf("graphrag.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// // // }
// // // if ropt.Count > 0 {
// // // count = ropt.Count
// // // }
// // // if ropt.MetricType != "" {
// // // metricTypeStr = ropt.MetricType
// // // }
// // // }
// // queryContent := partsToString(req.Query.Content)
// // // Format query for retrieval.
// // queryText := fmt.Sprintf("Search for: %s", queryContent)
// // username := "Alice" // Default, override if metadata available.
// // if req.Query.Metadata != nil {
// // if uname, ok := req.Query.Metadata["username"].(string); ok {
// // username = uname
// // }
// // }
// // // Prepare request for chat completions endpoint.
// // url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL)
// // chatReq := struct {
// // ConvUID string `json:"conv_uid"`
// // UserInput string `json:"user_input"`
// // UserName string `json:"user_name"`
// // ChatMode string `json:"chat_mode"`
// // AppCode string `json:"app_code"`
// // Temperature float32 `json:"temperature"`
// // MaxNewTokens int `json:"max_new_tokens"`
// // SelectParam string `json:"select_param"`
// // ModelName string `json:"model_name"`
// // Incremental bool `json:"incremental"`
// // SysCode string `json:"sys_code"`
// // PromptCode string `json:"prompt_code"`
// // ExtInfo map[string]interface{} `json:"ext_info"`
// // }{
// // ConvUID: "",
// // UserInput: queryText,
// // UserName: username,
// // ChatMode: "",
// // AppCode: "",
// // Temperature: 0.5,
// // MaxNewTokens: 4000,
// // SelectParam: "",
// // ModelName: ds.modelName,
// // Incremental: false,
// // SysCode: "",
// // PromptCode: "",
// // ExtInfo: map[string]interface{}{
// // "space_id": ds.spaceID,
// // //"k": count,
// // },
// // }
// // body, err := json.Marshal(chatReq)
// // if err != nil {
// // return nil, fmt.Errorf("marshal chat request: %w", err)
// // }
// // httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
// // if err != nil {
// // return nil, fmt.Errorf("create chat request: %w", err)
// // }
// // httpReq.Header.Set("Accept", "application/json")
// // httpReq.Header.Set("Content-Type", "application/json")
// // client := &http.Client{}
// // resp, err := client.Do(httpReq)
// // if err != nil {
// // return nil, fmt.Errorf("send chat request: %w", err)
// // }
// // defer resp.Body.Close()
// // if resp.StatusCode != http.StatusOK {
// // body, _ := io.ReadAll(resp.Body)
// // return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
// // }
// // // Parse response
// // var chatResp struct {
// // Success bool `json:"success"`
// // Data struct {
// // Answer []struct {
// // Content string `json:"content"`
// // DocID int `json:"doc_id"`
// // Score float64 `json:"score"`
// // Metadata map[string]interface{} `json:"metadata_map"`
// // } `json:"answer"`
// // } `json:"data"`
// // }
// // if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
// // return nil, fmt.Errorf("decode chat response: %w", err)
// // }
// // var docs []*ai.Document
// // for _, doc := range chatResp.Data.Answer {
// // metadata := doc.Metadata
// // if metadata == nil {
// // metadata = make(map[string]interface{})
// // }
// // // Ensure metadata includes user_id and username.
// // if _, ok := metadata["user_id"]; !ok {
// // metadata["user_id"] = "user123"
// // }
// // if _, ok := metadata["username"]; !ok {
// // metadata["username"] = username
// // }
// // aiDoc := ai.DocumentFromText(doc.Content, metadata)
// // docs = append(docs, aiDoc)
// // }
// // return &ai.RetrieverResponse{
// // Documents: docs,
// // }, nil
// // }
// // ChatRequest 定义请求结构体,匹配单元测试的 curl 请求
// type ChatRequest struct {
// Model string `json:"model"`
// Messages string `json:"messages"`
// Temperature float64 `json:"temperature"`
// TopP float64 `json:"top_p"`
// TopK int `json:"top_k"`
// N int `json:"n"`
// MaxTokens int `json:"max_tokens"`
// Stream bool `json:"stream"`
// RepetitionPenalty float64 `json:"repetition_penalty"`
// FrequencyPenalty float64 `json:"frequency_penalty"`
// PresencePenalty float64 `json:"presence_penalty"`
// ChatMode string `json:"chat_mode"`
// ChatParam string `json:"chat_param"`
// EnableVis bool `json:"enable_vis"`
// }
// // ChatResponse 定义响应结构体,匹配单元测试的 API 响应
// type ChatResponse struct {
// ID string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Model string `json:"model"`
// Choices []struct {
// Index int `json:"index"`
// Message struct {
// Role string `json:"role"`
// Content string `json:"content"`
// ReasoningContent interface{} `json:"reasoning_content"`
// } `json:"message"`
// FinishReason interface{} `json:"finish_reason"`
// } `json:"choices"`
// Usage struct {
// PromptTokens int `json:"prompt_tokens"`
// TotalTokens int `json:"total_tokens"`
// CompletionTokens int `json:"completion_tokens"`
// } `json:"usage"`
// }
// // Assuming ai.Part has a Text() method or Text field to get the string content
// func partsToString(parts []*ai.Part) string {
// var texts []string
// for _, part := range parts {
// // Adjust this based on the actual ai.Part structure
// // If ai.Part has a Text() method:
// texts = append(texts, part.Text)
// // OR if ai.Part has a Text field:
// // texts = append(texts, part.Text)
// }
// return strings.Join(texts, " ")
// }
// // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// // Format query for retrieval.
// queryContent := partsToString(req.Query.Content)
// queryText := fmt.Sprintf("Search for: %s", queryContent)
// if req.Query.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil
// return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// }
// for k, v := range req.Query.Metadata {
// fmt.Println("k", k, "v", v)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// // Prepare request for chat completions endpoint.
// url := fmt.Sprintf("%s/api/v2/chat/completions", ds.client.BaseURL)
// chatReq := ChatRequest{
// Model: ds.modelName,
// Messages: queryText,
// Temperature: 0.7,
// TopP: 1,
// TopK: -1,
// N: 1,
// MaxTokens: 0,
// Stream: false,
// RepetitionPenalty: 1,
// FrequencyPenalty: 0,
// PresencePenalty: 0,
// ChatMode: "chat_knowledge",
// ChatParam: userId,
// EnableVis: true,
// }
// body, err := json.Marshal(chatReq)
// if err != nil {
// return nil, fmt.Errorf("marshal chat request: %w", err)
// }
// httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
// if err != nil {
// return nil, fmt.Errorf("create chat request: %w", err)
// }
// httpReq.Header.Set("Accept", "application/json")
// httpReq.Header.Set("Content-Type", "application/json")
// client := &http.Client{}
// resp, err := client.Do(httpReq)
// if err != nil {
// return nil, fmt.Errorf("send chat request: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// body, _ := io.ReadAll(resp.Body)
// return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
// }
// // Parse response
// var chatResp ChatResponse
// if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
// body, _ := io.ReadAll(resp.Body)
// return nil, fmt.Errorf("decode chat response: %w, raw response: %s", err, string(body))
// }
// // Convert response to ai.Document
// var docs []*ai.Document
// if len(chatResp.Choices) > 0 {
// content := chatResp.Choices[0].Message.Content
// metadata := map[string]interface{}{
// util.UserIdKey: userId,
// util.UserNameKey: userName,
// }
// aiDoc := ai.DocumentFromText(content, metadata)
// docs = append(docs, aiDoc)
// }
// return &ai.RetrieverResponse{
// Documents: docs,
// }, nil
// }
......@@ -30,6 +30,7 @@ import (
"github.com/firebase/genkit/go/genkit"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/rs/zerolog/log"
"github.com/wade-liwei/agentchat/util"
)
......@@ -71,6 +72,7 @@ func (m *Milvus) Name() string {
// Init initializes the Milvus plugin.
func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
log.Info().Str("method", "Milvus.Init").Msg("Initializing Milvus plugin")
if m == nil {
m = &Milvus{}
}
......@@ -79,7 +81,10 @@ func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
defer m.mu.Unlock()
defer func() {
if err != nil {
log.Error().Err(err).Str("method", "Milvus.Init").Msg("Initialization failed")
err = fmt.Errorf("milvus.Init: %w", err)
} else {
log.Info().Str("method", "Milvus.Init").Msg("Initialization successful")
}
}()
......@@ -140,29 +145,41 @@ type CollectionConfig struct {
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
log.Info().
Str("method", "DefineIndexerAndRetriever").
Str("collection", cfg.Collection).
Int("dimension", cfg.Dimension).
Msg("Defining indexer and retriever")
if cfg.Embedder == nil {
log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Embedder required")
return nil, nil, errors.New("milvus: Embedder required")
}
if cfg.Collection == "" {
log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Collection name required")
return nil, nil, errors.New("milvus: collection name required")
}
if cfg.Dimension <= 0 {
log.Error().Str("method", "DefineIndexerAndRetriever").Int("dimension", cfg.Dimension).Msg("Dimension must be positive")
return nil, nil, errors.New("milvus: dimension must be positive")
}
m := genkit.LookupPlugin(g, provider)
if m == nil {
log.Error().Str("method", "DefineIndexerAndRetriever").Msg("Milvus plugin not found")
return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
}
milvus := m.(*Milvus)
ds, err := milvus.newDocStore(ctx, &cfg)
if err != nil {
log.Error().Err(err).Str("method", "DefineIndexerAndRetriever").Str("collection", cfg.Collection).Msg("Failed to create doc store")
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
log.Info().Str("method", "DefineIndexerAndRetriever").Str("collection", cfg.Collection).Msg("Indexer and retriever defined successfully")
return indexer, retriever, nil
}
......@@ -175,226 +192,23 @@ type docStore struct {
embedderOptions map[string]interface{}
}
// // newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil {
// return nil, errors.New("milvus.Init not called")
// }
// // Check/create collection.
// exists, err := m.client.HasCollection(ctx, cfg.Collection)
// if err != nil {
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
// }
// if !exists {
// // Define schema.
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: idField,
// DataType: entity.FieldTypeInt64,
// PrimaryKey: true,
// AutoID: true,
// },
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "65535",
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
// if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
// }
// // Create HNSW index.
// index, err := entity.NewIndexHNSW(
// entity.L2,
// 8, // M
// 96, // efConstruction
// )
// if err != nil {
// return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
// }
// err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
// if err != nil {
// return nil, fmt.Errorf("failed to create index: %v", err)
// }
// }
// // Load collection.
// err = m.client.LoadCollection(ctx, cfg.Collection, false)
// if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
// }
// // Convert EmbedderOptions to map[string]interface{}.
// var embedderOptions map[string]interface{}
// if cfg.EmbedderOptions != nil {
// opts, ok := cfg.EmbedderOptions.(map[string]interface{})
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
// }
// embedderOptions = opts
// } else {
// embedderOptions = make(map[string]interface{})
// }
// return &docStore{
// client: m.client,
// collection: cfg.Collection,
// dimension: cfg.Dimension,
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil {
// return nil, errors.New("milvus.Init not called")
// }
// // Check/create collection.
// exists, err := m.client.HasCollection(ctx, cfg.Collection)
// if err != nil {
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
// }
// if !exists {
// // Define schema with textField as primary key for unique constraint.
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// // {
// // Name: idField, // Optional non-primary ID field
// // DataType: entity.FieldTypeInt64,
// // //AutoID: true,
// // // No PrimaryKey or AutoID, as textField is the primary key
// // },
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// PrimaryKey: true, // Enforce unique constraint on text field
// TypeParams: map[string]string{
// "max_length": "65535", // Maximum length for VARCHAR, adjust if needed
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// // Alternative: Remove idField if not needed
// /*
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// PrimaryKey: true, // Enforce unique constraint on text field
// TypeParams: map[string]string{
// "max_length": "65535",
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// */
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
// if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
// }
// // Create HNSW index.
// index, err := entity.NewIndexHNSW(
// entity.L2,
// 8, // M
// 96, // efConstruction
// )
// if err != nil {
// return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
// }
// err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
// if err != nil {
// return nil, fmt.Errorf("failed to create index: %v", err)
// }
// }
// // Load collection.
// err = m.client.LoadCollection(ctx, cfg.Collection, false)
// if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
// }
// // Convert EmbedderOptions to map[string]interface{}.
// var embedderOptions map[string]interface{}
// if cfg.EmbedderOptions != nil {
// opts, ok := cfg.EmbedderOptions.(map[string]interface{})
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
// }
// embedderOptions = opts
// } else {
// embedderOptions = make(map[string]interface{})
// }
// return &docStore{
// client: m.client,
// collection: cfg.Collection,
// dimension: cfg.Dimension,
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// newDocStore creates a docStore.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
log.Info().
Str("method", "Milvus.newDocStore").
Str("collection", cfg.Collection).
Int("dimension", cfg.Dimension).
Msg("Creating new doc store")
if m.client == nil {
log.Error().Str("method", "Milvus.newDocStore").Msg("Milvus client not initialized")
return nil, errors.New("milvus.Init not called")
}
// Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to check collection")
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
}
if !exists {
......@@ -412,9 +226,9 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
PrimaryKey: true,
TypeParams: map[string]string{
"max_length": "65535", // Maximum length for VARCHAR
"max_length": "65535",
},
},
{
......@@ -425,14 +239,14 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
Name: "user_id",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128", // Reasonable length for user_id
"max_length": "128",
},
},
{
Name: "username",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "128", // Reasonable length for username
"max_length": "128",
},
},
},
......@@ -440,6 +254,7 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to create collection")
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
......@@ -450,11 +265,13 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
96, // efConstruction
)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to create HNSW index")
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
log.Error().Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msgf("Failed to create index: %s",err.Error())
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
......@@ -462,6 +279,7 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
log.Error().Err(err).Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Failed to load collection")
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
......@@ -470,6 +288,10 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
log.Error().
Str("method", "Milvus.newDocStore").
Str("type", fmt.Sprintf("%T", cfg.EmbedderOptions)).
Msg("EmbedderOptions must be a map[string]interface{}")
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
......@@ -477,6 +299,7 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
embedderOptions = make(map[string]interface{})
}
log.Info().Str("method", "Milvus.newDocStore").Str("collection", cfg.Collection).Msg("Doc store created successfully")
return &docStore{
client: m.client,
collection: cfg.Collection,
......@@ -488,21 +311,34 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
// Indexer returns the indexer for a collection.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
return genkit.LookupIndexer(g, provider, collection)
log.Info().Str("method", "Indexer").Str("collection", collection).Msg("Looking up indexer")
indexer := genkit.LookupIndexer(g, provider, collection)
if indexer == nil {
log.Warn().Str("method", "Indexer").Str("collection", collection).Msg("Indexer not found")
}
return indexer
}
// Retriever returns the retriever for a collection.
func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection)
log.Info().Str("method", "Retriever").Str("collection", collection).Msg("Looking up retriever")
retriever := genkit.LookupRetriever(g, provider, collection)
if retriever == nil {
log.Warn().Str("method", "Retriever").Str("collection", collection).Msg("Retriever not found")
}
return retriever
}
/*
更新 删除 很少用到;
*/
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
log.Info().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("documents", len(req.Documents)).
Msg("Starting index operation")
if len(req.Documents) == 0 {
log.Debug().Str("method", "docStore.Index").Str("collection", ds.collection).Msg("No documents to index")
return nil
}
......@@ -513,11 +349,18 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Index").Str("collection", ds.collection).Msg("Embedding failed")
return fmt.Errorf("milvus index embedding failed: %w", err)
}
// Validate embedding count matches document count.
if len(eres.Embeddings) != len(req.Documents) {
log.Error().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("embeddings", len(eres.Embeddings)).
Int("documents", len(req.Documents)).
Msg("Mismatch in embedding and document count")
return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
}
......@@ -527,17 +370,19 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
doc := req.Documents[i]
if doc.Metadata == nil {
// If ok, we don't use the User struct since the requirement is to error on non-nil
return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Document metadata is nil")
return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", doc.Metadata)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := doc.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Missing username in metadata")
return fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := doc.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Index").Int("index", i).Msg("Missing user_id in metadata")
return fmt.Errorf("req.Query.Metadata must provide user_id key")
}
......@@ -555,26 +400,39 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// Create row with explicit metadata field.
row := make(map[string]interface{})
row["vector"] = emb.Embedding // []float32
row["vector"] = emb.Embedding
row["text"] = text
row["user_id"] = userId
row["username"] = userName
row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
row["metadata"] = metadata
rows = append(rows, row)
// Debug: Log row contents.
fmt.Printf("Row %d: vector_len=%d, text=%q,userId=%s,username=%s,metadata=%v\n", i, len(emb.Embedding), text, userId, userName, metadata)
log.Debug().
Str("method", "docStore.Index").
Int("index", i).
Str("collection", ds.collection).
Int("vector_length", len(emb.Embedding)).
Str("text", text).
Str("user_id", userId).
Str("username", userName).
Interface("metadata", metadata).
Msg("Prepared row for insertion")
}
// Debug: Log total rows.
fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
log.Info().
Str("method", "docStore.Index").
Str("collection", ds.collection).
Int("rows", len(rows)).
Msg("Inserting rows into Milvus")
// Insert rows into Milvus.
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Index").Str("collection", ds.collection).Msg("Failed to insert rows")
return fmt.Errorf("milvus insert rows failed: %w", err)
}
log.Info().Str("method", "docStore.Index").Str("collection", ds.collection).Int("rows", len(rows)).Msg("Index operation completed successfully")
return nil
}
......@@ -584,173 +442,49 @@ type RetrieverOptions struct {
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
log.Info().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Msg("Starting retrieve operation")
// if req.Query.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil
// return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// }
if req.Query.Metadata == nil {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Query metadata is nil")
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
// // Extract username and user_id from req.Query.Metadata
// userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Missing username in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Missing user_id in metadata")
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
// count := 3 // Default.
// metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// // Map string metric type to entity.MetricType.
// var metricType entity.MetricType
// switch metricTypeStr {
// case "L2":
// metricType = entity.L2
// case "IP":
// metricType = entity.IP
// default:
// return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
// }
// // Embed query.
// ereq := &ai.EmbedRequest{
// Input: []*ai.Document{req.Query},
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
// }
// if len(eres.Embeddings) == 0 {
// return nil, errors.New("no embeddings generated for query")
// }
// queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// // Create search parameters.
// searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
// if err != nil {
// return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// }
// // Perform vector search to get IDs, text, and metadata.
// results, err := ds.client.Search(
// ctx,
// ds.collection,
// []string{}, // partitions
// "", // expr (TODO: add metadata filter if needed)
// []string{textField, metadataField}, // Output fields: text and metadata
// []entity.Vector{queryVector},
// vectorField,
// metricType,
// count,
// searchParams,
// )
// if err != nil {
// return nil, fmt.Errorf("milvus search failed: %v", err)
// }
// // Process search results.
// var docs []*ai.Document
// for _, result := range results {
// // Find text and metadata columns in search results.
// var textCol, metaCol entity.Column
// for _, col := range result.Fields {
// if col.Name() == textField {
// textCol = col
// }
// if col.Name() == metadataField {
// metaCol = col
// }
// }
// // Ensure text column exists.
// if textCol == nil {
// return nil, fmt.Errorf("text column %s not found in search results", textField)
// }
// // Iterate over rows (assuming columns have same length).
// for i := 0; i < result.ResultCount; i++ {
// // Get text value.
// text, err := textCol.GetAsString(i)
// if err != nil {
// fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
// continue
// }
// // Get metadata value (optional, as metadata column may be missing).
// var metadata map[string]interface{}
// if metaCol != nil {
// metaStr, err := metaCol.GetAsString(i)
// if err == nil && metaStr != "" {
// if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
// fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
// continue
// }
// } else if err != nil {
// fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
// }
// }
// // Print text and metadata in a format similar to insertion debug log.
// // fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata)
// // Create document.
// doc := ai.DocumentFromText(text, metadata)
// docs = append(docs, doc)
// }
// }
// return &ai.RetrieverResponse{
// Documents: docs,
// }, nil
// }
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
if req.Query.Metadata == nil {
return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
// Extract username and user_id from req.Query.Metadata
userName, ok := req.Query.Metadata[util.UserNameKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide username key")
}
userId, ok := req.Query.Metadata[util.UserIdKey].(string)
if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
}
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
log.Error().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("options_type", fmt.Sprintf("%T", req.Options)).
Msg("Invalid options type")
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
......@@ -760,6 +494,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
case "IP":
metricType = entity.IP
default:
log.Error().Str("method", "docStore.Retrieve").Str("metric_type", metricTypeStr).Msg("Unsupported metric type")
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
......@@ -770,9 +505,11 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Embedding failed")
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
log.Error().Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("No embeddings generated")
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
......@@ -780,19 +517,28 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Failed to create HNSW search parameters")
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Define filter expression for user_id
expr := fmt.Sprintf("user_id == %q", userId)
log.Debug().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("user_id", userId).
Str("metric_type", metricTypeStr).
Int("count", count).
Msg("Performing vector search")
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
expr, // Filter by user_id
[]string{textField, metadataField}, // Output fields: text and metadata
[]string{},
expr,
[]string{textField, metadataField},
[]entity.Vector{queryVector},
vectorField,
metricType,
......@@ -800,6 +546,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
searchParams,
)
if err != nil {
log.Error().Err(err).Str("method", "docStore.Retrieve").Str("collection", ds.collection).Msg("Search failed")
return nil, fmt.Errorf("milvus search failed: %v", err)
}
......@@ -819,33 +566,53 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Ensure text column exists.
if textCol == nil {
log.Error().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Str("field", textField).
Msg("Text column not found in search results")
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// Iterate over rows (assuming columns have same length).
// Iterate over rows.
for i := 0; i < result.ResultCount; i++ {
// Get text value.
text, err := textCol.GetAsString(i)
if err != nil {
fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to parse text")
continue
}
// Get metadata value (optional, as metadata column may be missing).
// Get metadata value (optional).
var metadata map[string]interface{}
if metaCol != nil {
metaStr, err := metaCol.GetAsString(i)
if err == nil && metaStr != "" {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to parse metadata")
continue
}
} else if err != nil {
fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
log.Error().
Err(err).
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Msg("Failed to get metadata string")
}
}
// Ensure metadata includes user_id and username from query
// Ensure metadata includes user_id and username from query.
if metadata == nil {
metadata = make(map[string]interface{})
}
......@@ -855,10 +622,570 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
log.Debug().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("index", i).
Str("text", text).
Interface("metadata", metadata).
Msg("Processed search result")
}
}
log.Info().
Str("method", "docStore.Retrieve").
Str("collection", ds.collection).
Int("documents", len(docs)).
Msg("Retrieve operation completed successfully")
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// // Package milvus provides a Genkit plugin for Milvus vector database using milvus-sdk-go.
// package milvus
// import (
// "context"
// "encoding/json"
// "errors"
// "fmt"
// "os"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/milvus-io/milvus-sdk-go/v2/client"
// "github.com/milvus-io/milvus-sdk-go/v2/entity"
// "github.com/wade-liwei/agentchat/util"
// )
// // The provider used in the registry.
// const provider = "milvus"
// // Field names for Milvus schema.
// const (
// idField = "id"
// vectorField = "vector"
// textField = "text"
// metadataField = "metadata"
// )
// // Milvus holds configuration for the plugin.
// type Milvus struct {
// // Milvus server address (host:port, e.g., "localhost:19530").
// // Defaults to MILVUS_ADDRESS environment variable.
// Addr string
// // Username for authentication.
// // Defaults to MILVUS_USERNAME.
// Username string
// // Password for authentication.
// // Defaults to MILVUS_PASSWORD.
// Password string
// // Token for authentication (alternative to username/password).
// // Defaults to MILVUS_TOKEN.
// Token string
// client client.Client // Milvus client.
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the plugin name.
// func (m *Milvus) Name() string {
// return provider
// }
// // Init initializes the Milvus plugin.
// func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
// if m == nil {
// m = &Milvus{}
// }
// m.mu.Lock()
// defer m.mu.Unlock()
// defer func() {
// if err != nil {
// err = fmt.Errorf("milvus.Init: %w", err)
// }
// }()
// if m.initted {
// return errors.New("plugin already initialized")
// }
// // Load configuration.
// addr := m.Addr
// if addr == "" {
// addr = os.Getenv("MILVUS_ADDRESS")
// }
// if addr == "" {
// return errors.New("milvus address required")
// }
// username := m.Username
// if username == "" {
// username = os.Getenv("MILVUS_USERNAME")
// }
// password := m.Password
// if password == "" {
// password = os.Getenv("MILVUS_PASSWORD")
// }
// token := m.Token
// if token == "" {
// token = os.Getenv("MILVUS_TOKEN")
// }
// // Initialize Milvus client.
// config := client.Config{
// Address: addr,
// Username: username,
// Password: password,
// APIKey: token,
// }
// client, err := client.NewClient(ctx, config)
// if err != nil {
// return fmt.Errorf("failed to initialize Milvus client: %v", err)
// }
// m.client = client
// m.initted = true
// return nil
// }
// // CollectionConfig holds configuration for an indexer/retriever pair.
// type CollectionConfig struct {
// // Milvus collection name. Must not be empty.
// Collection string
// // Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
// Dimension int
// // Embedder for generating vectors.
// Embedder ai.Embedder
// // Embedder options.
// EmbedderOptions any
// }
// // DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
// func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
// if cfg.Embedder == nil {
// return nil, nil, errors.New("milvus: Embedder required")
// }
// if cfg.Collection == "" {
// return nil, nil, errors.New("milvus: collection name required")
// }
// if cfg.Dimension <= 0 {
// return nil, nil, errors.New("milvus: dimension must be positive")
// }
// m := genkit.LookupPlugin(g, provider)
// if m == nil {
// return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
// }
// milvus := m.(*Milvus)
// ds, err := milvus.newDocStore(ctx, &cfg)
// if err != nil {
// return nil, nil, err
// }
// indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
// retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
// return indexer, retriever, nil
// }
// // docStore defines an Indexer and a Retriever.
// type docStore struct {
// client client.Client
// collection string
// dimension int
// embedder ai.Embedder
// embedderOptions map[string]interface{}
// }
// // newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil {
// return nil, errors.New("milvus.Init not called")
// }
// // Check/create collection.
// exists, err := m.client.HasCollection(ctx, cfg.Collection)
// if err != nil {
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
// }
// if !exists {
// // Define schema with textField as primary key, plus user_id and username fields.
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// PrimaryKey: true, // Enforce unique constraint on text field
// TypeParams: map[string]string{
// "max_length": "65535", // Maximum length for VARCHAR
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// {
// Name: "user_id",
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "128", // Reasonable length for user_id
// },
// },
// {
// Name: "username",
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "128", // Reasonable length for username
// },
// },
// },
// }
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
// if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
// }
// // Create HNSW index for vectorField.
// index, err := entity.NewIndexHNSW(
// entity.L2,
// 8, // M
// 96, // efConstruction
// )
// if err != nil {
// return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
// }
// err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
// if err != nil {
// return nil, fmt.Errorf("failed to create index: %v", err)
// }
// }
// // Load collection.
// err = m.client.LoadCollection(ctx, cfg.Collection, false)
// if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
// }
// // Convert EmbedderOptions to map[string]interface{}.
// var embedderOptions map[string]interface{}
// if cfg.EmbedderOptions != nil {
// opts, ok := cfg.EmbedderOptions.(map[string]interface{})
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
// }
// embedderOptions = opts
// } else {
// embedderOptions = make(map[string]interface{})
// }
// return &docStore{
// client: m.client,
// collection: cfg.Collection,
// dimension: cfg.Dimension,
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// // Indexer returns the indexer for a collection.
// func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
// return genkit.LookupIndexer(g, provider, collection)
// }
// // Retriever returns the retriever for a collection.
// func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
// return genkit.LookupRetriever(g, provider, collection)
// }
// /*
// 更新 删除 很少用到;
// */
// // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 {
// return nil
// }
// // Embed documents.
// ereq := &ai.EmbedRequest{
// Input: req.Documents,
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return fmt.Errorf("milvus index embedding failed: %w", err)
// }
// // Validate embedding count matches document count.
// if len(eres.Embeddings) != len(req.Documents) {
// return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
// }
// // Prepare row-based data.
// var rows []interface{}
// for i, emb := range eres.Embeddings {
// doc := req.Documents[i]
// if doc.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil
// return fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := doc.Metadata[util.UserNameKey].(string)
// if !ok {
// return fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := doc.Metadata[util.UserIdKey].(string)
// if !ok {
// return fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// metadata := doc.Metadata
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// // Create row with explicit metadata field.
// row := make(map[string]interface{})
// row["vector"] = emb.Embedding // []float32
// row["text"] = text
// row["user_id"] = userId
// row["username"] = userName
// row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
// rows = append(rows, row)
// // Debug: Log row contents.
// fmt.Printf("Row %d: vector_len=%d, text=%q,userId=%s,username=%s,metadata=%v\n", i, len(emb.Embedding), text, userId, userName, metadata)
// }
// // Debug: Log total rows.
// fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// // Insert rows into Milvus.
// _, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
// if err != nil {
// return fmt.Errorf("milvus insert rows failed: %w", err)
// }
// return nil
// }
// // RetrieverOptions for Milvus retrieval.
// type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve.
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// }
// // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// if req.Query.Metadata == nil {
// return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// count := 3 // Default.
// metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// // Map string metric type to entity.MetricType.
// var metricType entity.MetricType
// switch metricTypeStr {
// case "L2":
// metricType = entity.L2
// case "IP":
// metricType = entity.IP
// default:
// return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
// }
// // Embed query.
// ereq := &ai.EmbedRequest{
// Input: []*ai.Document{req.Query},
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
// }
// if len(eres.Embeddings) == 0 {
// return nil, errors.New("no embeddings generated for query")
// }
// queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// // Create search parameters.
// searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
// if err != nil {
// return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// }
// // Define filter expression for user_id
// expr := fmt.Sprintf("user_id == %q", userId)
// // Perform vector search to get IDs, text, and metadata.
// results, err := ds.client.Search(
// ctx,
// ds.collection,
// []string{}, // partitions
// expr, // Filter by user_id
// []string{textField, metadataField}, // Output fields: text and metadata
// []entity.Vector{queryVector},
// vectorField,
// metricType,
// count,
// searchParams,
// )
// if err != nil {
// return nil, fmt.Errorf("milvus search failed: %v", err)
// }
// // Process search results.
// var docs []*ai.Document
// for _, result := range results {
// // Find text and metadata columns in search results.
// var textCol, metaCol entity.Column
// for _, col := range result.Fields {
// if col.Name() == textField {
// textCol = col
// }
// if col.Name() == metadataField {
// metaCol = col
// }
// }
// // Ensure text column exists.
// if textCol == nil {
// return nil, fmt.Errorf("text column %s not found in search results", textField)
// }
// // Iterate over rows (assuming columns have same length).
// for i := 0; i < result.ResultCount; i++ {
// // Get text value.
// text, err := textCol.GetAsString(i)
// if err != nil {
// fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
// continue
// }
// // Get metadata value (optional, as metadata column may be missing).
// var metadata map[string]interface{}
// if metaCol != nil {
// metaStr, err := metaCol.GetAsString(i)
// if err == nil && metaStr != "" {
// if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
// fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
// continue
// }
// } else if err != nil {
// fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
// }
// }
// // Ensure metadata includes user_id and username from query
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// metadata[util.UserIdKey] = userId
// metadata[util.UserNameKey] = userName
// // Create document.
// doc := ai.DocumentFromText(text, metadata)
// docs = append(docs, doc)
// }
// }
// return &ai.RetrieverResponse{
// Documents: docs,
// }, nil
// }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment