Commit 96fa2c8b authored by Wade's avatar Wade

rag ok

parent a99c5c92
...@@ -132,35 +132,44 @@ curl -X POST http://localhost:8000/indexGraph \ ...@@ -132,35 +132,44 @@ curl -X POST http://localhost:8000/indexGraph \
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "Document indexed successfully"}
d1 := ai.DocumentFromText("Paris is the capital of France", nil)
d2 := ai.DocumentFromText("USA is the largest importer of coffee", nil)
d3 := ai.DocumentFromText("Water exists in 3 states - solid, liquid and gas", nil)
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "Paris is the capital of France"}
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "USA is the largest importer of coffee"}
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "Water exists in 3 states - solid, liquid and gas", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "USA is the largest importer of coffee"}
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "Document indexed successfully"}
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"strings"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
...@@ -47,6 +48,16 @@ type GraphInput struct { ...@@ -47,6 +48,16 @@ type GraphInput struct {
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
const simpleQaPromptTemplate = `
You're a helpful agent that answers the user's common questions based on the context provided.
Here is the user's query: {{query}}
Here is the context you should use: {{context}}
Please provide the best answer you can.
`
func main() { func main() {
ctx := context.Background() ctx := context.Background()
...@@ -79,12 +90,12 @@ func main() { ...@@ -79,12 +90,12 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
m := ds.DefineModel(g, // m := ds.DefineModel(g,
deepseek.ModelDefinition{ // deepseek.ModelDefinition{
Name: "deepseek-chat", // Choose an appropriate model // Name: "deepseek-chat", // Choose an appropriate model
Type: "chat", // Must be chat for tool support // Type: "chat", // Must be chat for tool support
}, // },
nil) // nil)
embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001")
if embedder == nil { if embedder == nil {
...@@ -157,6 +168,16 @@ func main() { ...@@ -157,6 +168,16 @@ func main() {
return fmt.Sprintf("Document indexed successfully, docname %s", resDocName), nil return fmt.Sprintf("Document indexed successfully, docname %s", resDocName), nil
}) })
simpleQaPrompt, err := genkit.DefinePrompt(g, "simpleQaPrompt",
ai.WithModelName("googleai/gemini-2.0-flash"),
ai.WithPrompt(simpleQaPromptTemplate),
ai.WithInputType(simpleQaPromptInput{}),
ai.WithOutputFormat(ai.OutputFormatText),
)
if err != nil {
log.Fatal(err)
}
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) {
...@@ -175,29 +196,28 @@ func main() { ...@@ -175,29 +196,28 @@ func main() {
} }
for _, d := range response.Documents { for _, d := range response.Documents {
fmt.Println("d.Content[0].Text",d.Content[0].Text) fmt.Println("d.Content[0].Text", d.Content[0].Text)
} }
return "",nil var sb strings.Builder
for _, d := range response.Documents {
sb.WriteString(d.Content[0].Text)
resp, err := genkit.Generate(ctx, g, sb.WriteByte('\n')
ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`))
if err != nil {
fmt.Println(err.Error())
return "", err
} }
fmt.Println("resp.Text()", resp.Text()) promptInput := &simpleQaPromptInput{
Query: input.Content,
Context: sb.String(),
}
resp, err := simpleQaPrompt.Execute(ctx, ai.WithInput(promptInput))
if err != nil { if err != nil {
return "", err return "", err
} }
return resp.Text(), nil
text := resp.Text() //ai.WithPrompt(promptInput))
return text, nil //ai.WithPrompt(`Tell silly short jokes about apple`)
}) })
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5 // 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
...@@ -225,3 +245,8 @@ func main() { ...@@ -225,3 +245,8 @@ func main() {
log.Fatalf("Server failed: %v", err) log.Fatalf("Server failed: %v", err)
} }
} }
type simpleQaPromptInput struct {
Query string `json:"query"`
Context string `json:"context"`
}
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
// func TestGenerateEmbedding(t *testing.T) // func TestGenerateEmbedding(t *testing.T)
func TestGenerateEmbedding(t *testing.T){ func TestGenerateEmbedding(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Initialize Genkit with Google AI plugin // Initialize Genkit with Google AI plugin
......
...@@ -265,138 +265,125 @@ type docStore struct { ...@@ -265,138 +265,125 @@ type docStore struct {
// }, nil // }, nil
// } // }
// newDocStore creates a docStore. // newDocStore creates a docStore.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) { func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
if m.client == nil { if m.client == nil {
return nil, errors.New("milvus.Init not called") return nil, errors.New("milvus.Init not called")
} }
// Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil {
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
}
if !exists {
// Define schema with textField as primary key for unique constraint.
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
// {
// Name: idField, // Optional non-primary ID field
// DataType: entity.FieldTypeInt64,
// //AutoID: true,
// // No PrimaryKey or AutoID, as textField is the primary key
// },
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
TypeParams: map[string]string{
"max_length": "65535", // Maximum length for VARCHAR, adjust if needed
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
},
}
// Alternative: Remove idField if not needed
/*
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
TypeParams: map[string]string{
"max_length": "65535",
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
},
}
*/
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
// Create HNSW index.
index, err := entity.NewIndexHNSW(
entity.L2,
8, // M
96, // efConstruction
)
if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
// Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
} else {
embedderOptions = make(map[string]interface{})
}
return &docStore{
client: m.client,
collection: cfg.Collection,
dimension: cfg.Dimension,
embedder: cfg.Embedder,
embedderOptions: embedderOptions,
}, nil
}
// Check/create collection.
exists, err := m.client.HasCollection(ctx, cfg.Collection)
if err != nil {
return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
}
if !exists {
// Define schema with textField as primary key for unique constraint.
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
// {
// Name: idField, // Optional non-primary ID field
// DataType: entity.FieldTypeInt64,
// //AutoID: true,
// // No PrimaryKey or AutoID, as textField is the primary key
// },
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
TypeParams: map[string]string{
"max_length": "65535", // Maximum length for VARCHAR, adjust if needed
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
},
}
// Alternative: Remove idField if not needed
/*
schema := &entity.Schema{
CollectionName: cfg.Collection,
Fields: []*entity.Field{
{
Name: vectorField,
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", cfg.Dimension),
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
PrimaryKey: true, // Enforce unique constraint on text field
TypeParams: map[string]string{
"max_length": "65535",
},
},
{
Name: metadataField,
DataType: entity.FieldTypeJSON,
},
},
}
*/
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
// Create HNSW index.
index, err := entity.NewIndexHNSW(
entity.L2,
8, // M
96, // efConstruction
)
if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
// Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
} else {
embedderOptions = make(map[string]interface{})
}
return &docStore{
client: m.client,
collection: cfg.Collection,
dimension: cfg.Dimension,
embedder: cfg.Embedder,
embedderOptions: embedderOptions,
}, nil
}
// Indexer returns the indexer for a collection. // Indexer returns the indexer for a collection.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer { func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
...@@ -472,10 +459,6 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -472,10 +459,6 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
return nil return nil
} }
// // RetrieverOptions for Milvus retrieval. // // RetrieverOptions for Milvus retrieval.
// type RetrieverOptions struct { // type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve. // Count int `json:"count,omitempty"` // Max documents to retrieve.
...@@ -653,138 +636,129 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { ...@@ -653,138 +636,129 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// return strings.Join(strs, sep) // return strings.Join(strs, sep)
// } // }
// RetrieverOptions for Milvus retrieval.
type RetrieverOptions struct {
Count int `json:"count,omitempty"` // Max documents to retrieve.
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
switch metricTypeStr {
case "L2":
metricType = entity.L2
case "IP":
metricType = entity.IP
default:
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
"", // expr (TODO: add metadata filter if needed)
[]string{textField, metadataField}, // Output fields: text and metadata
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists.
if textCol == nil {
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// RetrieverOptions for Milvus retrieval. // Iterate over rows (assuming columns have same length).
type RetrieverOptions struct { for i := 0; i < result.ResultCount; i++ {
Count int `json:"count,omitempty"` // Max documents to retrieve. // Get text value.
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP"). text, err := textCol.GetAsString(i)
} if err != nil {
fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
continue
}
// Retrieve implements the Retriever.Retrieve method. // Get metadata value (optional, as metadata column may be missing).
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { var metadata map[string]interface{}
count := 3 // Default. if metaCol != nil {
metricTypeStr := "L2" metaStr, err := metaCol.GetAsString(i)
if req.Options != nil { if err == nil && metaStr != "" {
ropt, ok := req.Options.(*RetrieverOptions) if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
if !ok { fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{}) continue
} }
if ropt.Count > 0 { } else if err != nil {
count = ropt.Count fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
} }
if ropt.MetricType != "" { }
metricTypeStr = ropt.MetricType
} // Print text and metadata in a format similar to insertion debug log.
} // fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata)
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
switch metricTypeStr {
case "L2":
metricType = entity.L2
case "IP":
metricType = entity.IP
default:
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
"", // expr (TODO: add metadata filter if needed)
[]string{textField, metadataField}, // Output fields: text and metadata
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists.
if textCol == nil {
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// Iterate over rows (assuming columns have same length).
for i := 0; i < result.ResultCount; i++ {
// Get text value.
text, err := textCol.GetAsString(i)
if err != nil {
fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
continue
}
// Get metadata value (optional, as metadata column may be missing).
var metadata map[string]interface{}
if metaCol != nil {
metaStr, err := metaCol.GetAsString(i)
if err == nil && metaStr != "" {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
continue
}
} else if err != nil {
fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
}
}
// Print text and metadata in a format similar to insertion debug log.
// fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata)
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
}
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
}
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment