Commit 654f62a4 authored by Wade's avatar Wade

add user field

parent 8864a9d6
...@@ -23,13 +23,13 @@ import ( ...@@ -23,13 +23,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/milvus-io/milvus-sdk-go/v2/client" "github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity" "github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/wade-liwei/agentchat/util"
) )
// The provider used in the registry. // The provider used in the registry.
...@@ -266,375 +266,479 @@ type docStore struct { ...@@ -266,375 +266,479 @@ type docStore struct {
// } // }
// newDocStore creates a docStore bound to cfg.Collection.
//
// When the collection does not exist yet it is created with three fields:
// vectorField (float vector with cfg.Dimension dimensions), textField
// (VARCHAR primary key) and metadataField (JSON). An HNSW index using the L2
// metric is then built on vectorField. In every case the collection is loaded
// before the docStore is returned.
//
// An error is returned when m.client is nil, when cfg is nil, when any Milvus
// call fails, or when cfg.EmbedderOptions is non-nil but is not a
// map[string]interface{}. Underlying errors are wrapped with %w so callers can
// inspect them with errors.Is / errors.As.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
	if m.client == nil {
		return nil, errors.New("milvus.Init not called")
	}
	if cfg == nil {
		return nil, errors.New("milvus: collection config must not be nil")
	}

	// Check/create collection.
	exists, err := m.client.HasCollection(ctx, cfg.Collection)
	if err != nil {
		return nil, fmt.Errorf("failed to check collection %q: %w", cfg.Collection, err)
	}
	if !exists {
		schema := &entity.Schema{
			CollectionName: cfg.Collection,
			Fields: []*entity.Field{
				{
					Name:     vectorField,
					DataType: entity.FieldTypeFloatVector,
					TypeParams: map[string]string{
						"dim": fmt.Sprintf("%d", cfg.Dimension),
					},
				},
				{
					Name:     textField,
					DataType: entity.FieldTypeVarChar,
					// The document text doubles as the primary key.
					// NOTE(review): the original intent was a uniqueness
					// constraint on the text; confirm the Milvus version in
					// use actually rejects duplicate primary keys.
					PrimaryKey: true,
					TypeParams: map[string]string{
						"max_length": "65535", // Maximum length for VARCHAR, adjust if needed
					},
				},
				{
					Name:     metadataField,
					DataType: entity.FieldTypeJSON,
				},
			},
		}

		if err := m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil {
			return nil, fmt.Errorf("failed to create collection %q: %w", cfg.Collection, err)
		}

		// Create HNSW index.
		index, err := entity.NewIndexHNSW(
			entity.L2,
			8,  // M
			96, // efConstruction
		)
		if err != nil {
			return nil, fmt.Errorf("entity.NewIndexHNSW: %w", err)
		}
		if err := m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false); err != nil {
			return nil, fmt.Errorf("failed to create index: %w", err)
		}
	}

	// Load collection.
	if err := m.client.LoadCollection(ctx, cfg.Collection, false); err != nil {
		return nil, fmt.Errorf("failed to load collection %q: %w", cfg.Collection, err)
	}

	// Convert EmbedderOptions to map[string]interface{}; a nil value becomes
	// an empty (non-nil) map.
	embedderOptions := make(map[string]interface{})
	if cfg.EmbedderOptions != nil {
		opts, ok := cfg.EmbedderOptions.(map[string]interface{})
		if !ok {
			return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
		}
		embedderOptions = opts
	}

	return &docStore{
		client:          m.client,
		collection:      cfg.Collection,
		dimension:       cfg.Dimension,
		embedder:        cfg.Embedder,
		embedderOptions: embedderOptions,
	}, nil
}
// } else {
// embedderOptions = make(map[string]interface{})
// }
// Indexer looks up the indexer registered under this provider for the named
// collection and returns whatever the Genkit registry lookup yields.
func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
	indexer := genkit.LookupIndexer(g, provider, collection)
	return indexer
}
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// Retriever looks up the retriever registered under this provider for the
// named collection and returns whatever the Genkit registry lookup yields.
func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
	retriever := genkit.LookupRetriever(g, provider, collection)
	return retriever
}
/*
Update and delete operations are rarely used.
*/
// Index implements the Indexer.Index method.
//
// Every document in req.Documents is embedded with ds.embedder (passing
// ds.embedderOptions), the text parts of each document are concatenated, and
// one row per document holding the vector, the text and the metadata is
// inserted into ds.collection. A nil request or a request without documents is
// a no-op.
//
// An error is returned when embedding fails, when the number of embeddings
// does not match the number of documents, when a document is nil, or when the
// Milvus insert fails.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
	if req == nil || len(req.Documents) == 0 {
		return nil
	}

	// Embed documents.
	ereq := &ai.EmbedRequest{
		Input:   req.Documents,
		Options: ds.embedderOptions,
	}
	eres, err := ds.embedder.Embed(ctx, ereq)
	if err != nil {
		return fmt.Errorf("milvus index embedding failed: %w", err)
	}

	// Validate embedding count matches document count.
	if len(eres.Embeddings) != len(req.Documents) {
		return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
	}

	// Prepare row-based data.
	rows := make([]interface{}, 0, len(eres.Embeddings))
	for i, emb := range eres.Embeddings {
		doc := req.Documents[i]
		if doc == nil {
			return fmt.Errorf("milvus index: document %d is nil", i)
		}

		// Only text parts contribute to the stored text.
		var sb strings.Builder
		for _, p := range doc.Content {
			if p.IsText() {
				sb.WriteString(p.Text)
			}
		}
		text := sb.String()

		metadata := doc.Metadata
		if metadata == nil {
			metadata = make(map[string]interface{})
		}

		// Key the row by the same field-name identifiers the collection
		// schema and Retrieve use, so the names cannot drift apart.
		row := map[string]interface{}{
			vectorField:   emb.Embedding, // []float32
			textField:     text,
			metadataField: metadata, // JSON-compatible map
		}
		rows = append(rows, row)

		// Debug: Log row contents.
		// NOTE(review): this writes document text to stdout; consider
		// removing it or routing it through a logger.
		fmt.Printf("Row %d: vector_len=%d, text=%q, metadata=%v\n", i, len(emb.Embedding), text, metadata)
	}

	// Debug: Log total rows.
	fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)

	// Insert rows into Milvus (empty partition name).
	if _, err := ds.client.InsertRows(ctx, ds.collection, "", rows); err != nil {
		return fmt.Errorf("milvus insert rows failed: %w", err)
	}

	return nil
}
// if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
// }
// // Convert EmbedderOptions to map[string]interface{}.
// var embedderOptions map[string]interface{}
// if cfg.EmbedderOptions != nil {
// opts, ok := cfg.EmbedderOptions.(map[string]interface{})
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
// }
// embedderOptions = opts
// } else {
// embedderOptions = make(map[string]interface{})
// }
// return &docStore{
// client: m.client,
// collection: cfg.Collection,
// dimension: cfg.Dimension,
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// // Indexer returns the indexer for a collection.
// func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
// return genkit.LookupIndexer(g, provider, collection)
// }
// // Retriever returns the retriever for a collection.
// func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
// return genkit.LookupRetriever(g, provider, collection)
// }
// /*
// 更新 删除 很少用到;
// */
// // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 {
// return nil
// }
// // Embed documents.
// ereq := &ai.EmbedRequest{
// Input: req.Documents,
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return fmt.Errorf("milvus index embedding failed: %w", err)
// }
// // Validate embedding count matches document count.
// if len(eres.Embeddings) != len(req.Documents) {
// return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
// }
// // Prepare row-based data.
// var rows []interface{}
// for i, emb := range eres.Embeddings {
// doc := req.Documents[i]
// if doc.Metadata == nil {
// // If ok, we don't use the User struct since the requirement is to error on non-nil
// return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// }
// // Extract username and user_id from req.Query.Metadata
// userName, ok := doc.Metadata[util.UserNameKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// }
// userId, ok := doc.Metadata[util.UserIdKey].(string)
// if !ok {
// return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// }
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// metadata := doc.Metadata
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// // Create row with explicit metadata field.
// row := make(map[string]interface{})
// row["vector"] = emb.Embedding // []float32
// row["text"] = text
// row["user_id"] = userId
// row["username"] = userName
// row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
// rows = append(rows, row)
// // Debug: Log row contents.
// fmt.Printf("Row %d: vector_len=%d, text=%q,userId=%s,username=%s,metadata=%v\n", i, len(emb.Embedding), text,userId,userName metadata)
// }
// // Debug: Log total rows.
// fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// // Insert rows into Milvus.
// _, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
// if err != nil {
// return fmt.Errorf("milvus insert rows failed: %w", err)
// }
// return nil
// }
// // RetrieverOptions for Milvus retrieval. // // RetrieverOptions for Milvus retrieval.
// type RetrieverOptions struct { // type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve. // Count int `json:"count,omitempty"` // Max documents to retrieve.
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP"). // MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// } // }
// // Retrieve implements the Retriever.Retrieve method. // // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { // func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// count := 3 // Default.
// metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// // Map string metric type to entity.MetricType.
// var metricType entity.MetricType
// switch metricTypeStr {
// case "L2":
// metricType = entity.L2
// case "IP":
// metricType = entity.IP
// default:
// return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
// }
// // Embed query. // if req.Query.Metadata == nil {
// ereq := &ai.EmbedRequest{ // // If ok, we don't use the User struct since the requirement is to error on non-nil
// Input: []*ai.Document{req.Query}, // return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Options)
// Options: ds.embedderOptions, // }
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
// }
// if len(eres.Embeddings) == 0 {
// return nil, errors.New("no embeddings generated for query")
// }
// queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// // Create search parameters.
// searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
// if err != nil {
// return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// }
// // Perform vector search to get IDs. // // Extract username and user_id from req.Query.Metadata
// results, err := ds.client.Search( // userName, ok := req.Query.Metadata[util.UserNameKey].(string)
// ctx, // if !ok {
// ds.collection, // return nil, fmt.Errorf("req.Query.Metadata must provide username key")
// []string{}, // partitions // }
// "", // expr (TODO: add metadata filter if needed) // userId, ok := req.Query.Metadata[util.UserIdKey].(string)
// []string{}, // Only need IDs for now, no output fields // if !ok {
// []entity.Vector{queryVector}, // return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// vectorField, // }
// metricType,
// count,
// searchParams,
// )
// if err != nil {
// return nil, fmt.Errorf("milvus search failed: %v", err)
// }
// // Extract IDs from search results.
// var ids []int64
// for _, result := range results {
// for i := 0; i < result.ResultCount; i++ {
// id, err := result.IDs.GetAsInt64(i)
// if err != nil {
// continue
// }
// ids = append(ids, id)
// }
// }
// if len(ids) == 0 {
// return &ai.RetrieverResponse{
// Documents: []*ai.Document{},
// }, nil
// }
// // Construct filter expression for Query (e.g., "id IN [id1, id2, ...]"). // count := 3 // Default.
// filterExpr := fmt.Sprintf("id IN [%s]", joinInt64s(ids, ",")) // metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// // Perform query to get text and metadata. // // Map string metric type to entity.MetricType.
// queryOptions := []client.SearchQueryOptionFunc{ // var metricType entity.MetricType
// client.WithLimit(int64(count)), // switch metricTypeStr {
// } // case "L2":
// // Note: Consistency level omitted due to undefined WithQueryConsistencyLevel. // metricType = entity.L2
// // If WithConsistencyLevel is supported for Query in your SDK, uncomment below: // case "IP":
// // queryOptions = append(queryOptions, client.WithConsistencyLevel(entity.ConsistencyBounded)) // metricType = entity.IP
// default:
// queryResults, err := ds.client.Query( // return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
// ctx, // }
// ds.collection,
// []string{}, // partitions
// filterExpr, // filter by IDs
// []string{textField, metadataField}, // output fields
// queryOptions...,
// )
// if err != nil {
// return nil, fmt.Errorf("milvus query failed: %v", err)
// }
// // Process query results. // // Embed query.
// var docs []*ai.Document // ereq := &ai.EmbedRequest{
// // Find text and metadata columns in query results. // Input: []*ai.Document{req.Query},
// var textCol, metaCol entity.Column // Options: ds.embedderOptions,
// for _, col := range queryResults { // }
// if col.Name() == textField { // eres, err := ds.embedder.Embed(ctx, ereq)
// textCol = col // if err != nil {
// } // return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
// if col.Name() == metadataField { // }
// metaCol = col // if len(eres.Embeddings) == 0 {
// } // return nil, errors.New("no embeddings generated for query")
// } // }
// queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// // Ensure text column exists. // // Create search parameters.
// if textCol == nil { // searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
// return nil, fmt.Errorf("text column %s not found in query results", textField) // if err != nil {
// } // return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// }
// // Iterate over rows (assuming columns have same length). // // Perform vector search to get IDs, text, and metadata.
// for i := 0; i < textCol.Len(); i++ { // results, err := ds.client.Search(
// // Get text value. // ctx,
// text, err := textCol.GetAsString(i) // ds.collection,
// if err != nil { // []string{}, // partitions
// fmt.Printf("Failed to parse text at index %d: %v\n", i, err) // "", // expr (TODO: add metadata filter if needed)
// continue // []string{textField, metadataField}, // Output fields: text and metadata
// } // []entity.Vector{queryVector},
// vectorField,
// metricType,
// count,
// searchParams,
// )
// if err != nil {
// return nil, fmt.Errorf("milvus search failed: %v", err)
// }
// // Get metadata value (optional, as metadata column may be missing). // // Process search results.
// var metadata map[string]interface{} // var docs []*ai.Document
// if metaCol != nil { // for _, result := range results {
// metaStr, err := metaCol.GetAsString(i) // // Find text and metadata columns in search results.
// if err == nil && metaStr != "" { // var textCol, metaCol entity.Column
// if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil { // for _, col := range result.Fields {
// fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err) // if col.Name() == textField {
// continue // textCol = col
// } // }
// } else if err != nil { // if col.Name() == metadataField {
// fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err) // metaCol = col
// } // }
// } // }
// // Print text and metadata in a format similar to insertion debug log. // // Ensure text column exists.
// fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata) // if textCol == nil {
// return nil, fmt.Errorf("text column %s not found in search results", textField)
// }
// // Create document. // // Iterate over rows (assuming columns have same length).
// doc := ai.DocumentFromText(text, metadata) // for i := 0; i < result.ResultCount; i++ {
// docs = append(docs, doc) // // Get text value.
// } // text, err := textCol.GetAsString(i)
// if err != nil {
// fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
// continue
// }
// // Get metadata value (optional, as metadata column may be missing).
// var metadata map[string]interface{}
// if metaCol != nil {
// metaStr, err := metaCol.GetAsString(i)
// if err == nil && metaStr != "" {
// if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
// fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
// continue
// }
// } else if err != nil {
// fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
// }
// }
// // Print text and metadata in a format similar to insertion debug log.
// // fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata)
// // Create document.
// doc := ai.DocumentFromText(text, metadata)
// docs = append(docs, doc)
// }
// }
// return &ai.RetrieverResponse{ // return &ai.RetrieverResponse{
// Documents: docs, // Documents: docs,
// }, nil // }, nil
// } // }
// // joinInt64s converts a slice of int64 to a comma-separated string.
// func joinInt64s(ids []int64, sep string) string {
// if len(ids) == 0 {
// return ""
// }
// strs := make([]string, len(ids))
// for i, id := range ids {
// strs[i] = fmt.Sprintf("%d", id)
// }
// return strings.Join(strs, sep)
// }
// RetrieverOptions for Milvus retrieval. // RetrieverOptions for Milvus retrieval.
type RetrieverOptions struct { type RetrieverOptions struct {
...@@ -642,123 +746,147 @@ type RetrieverOptions struct { ...@@ -642,123 +746,147 @@ type RetrieverOptions struct {
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP"). MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
} }
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
count := 3 // Default.
metricTypeStr := "L2"
if req.Options != nil {
ropt, ok := req.Options.(*RetrieverOptions)
if !ok {
return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
}
if ropt.Count > 0 {
count = ropt.Count
}
if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
}
}
// Map string metric type to entity.MetricType.
var metricType entity.MetricType
switch metricTypeStr {
case "L2":
metricType = entity.L2
case "IP":
metricType = entity.IP
default:
return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
"", // expr (TODO: add metadata filter if needed)
[]string{textField, metadataField}, // Output fields: text and metadata
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists. // Retrieve implements the Retriever.Retrieve method.
if textCol == nil { func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
return nil, fmt.Errorf("text column %s not found in search results", textField) if req.Query.Metadata == nil {
} return nil, fmt.Errorf("req.Query.Metadata must be not nil, got type %T", req.Query.Metadata)
}
// Iterate over rows (assuming columns have same length).
for i := 0; i < result.ResultCount; i++ { // Extract username and user_id from req.Query.Metadata
// Get text value. userName, ok := req.Query.Metadata[util.UserNameKey].(string)
text, err := textCol.GetAsString(i) if !ok {
if err != nil { return nil, fmt.Errorf("req.Query.Metadata must provide username key")
fmt.Printf("Failed to parse text at index %d: %v\n", i, err) }
continue userId, ok := req.Query.Metadata[util.UserIdKey].(string)
} if !ok {
return nil, fmt.Errorf("req.Query.Metadata must provide user_id key")
// Get metadata value (optional, as metadata column may be missing). }
var metadata map[string]interface{}
if metaCol != nil { count := 3 // Default.
metaStr, err := metaCol.GetAsString(i) metricTypeStr := "L2"
if err == nil && metaStr != "" { if req.Options != nil {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil { ropt, ok := req.Options.(*RetrieverOptions)
fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err) if !ok {
continue return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
} }
} else if err != nil { if ropt.Count > 0 {
fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err) count = ropt.Count
} }
} if ropt.MetricType != "" {
metricTypeStr = ropt.MetricType
// Print text and metadata in a format similar to insertion debug log. }
// fmt.Printf("Row %d: text=%q, metadata=%v\n", i, text, metadata) }
// Create document. // Map string metric type to entity.MetricType.
doc := ai.DocumentFromText(text, metadata) var metricType entity.MetricType
docs = append(docs, doc) switch metricTypeStr {
} case "L2":
} metricType = entity.L2
case "IP":
return &ai.RetrieverResponse{ metricType = entity.IP
Documents: docs, default:
}, nil return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
}
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
}
if len(eres.Embeddings) == 0 {
return nil, errors.New("no embeddings generated for query")
}
queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// Create search parameters.
searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
if err != nil {
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
}
// Define filter expression for user_id
expr := fmt.Sprintf("user_id == %q", userId)
// Perform vector search to get IDs, text, and metadata.
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions
expr, // Filter by user_id
[]string{textField, metadataField}, // Output fields: text and metadata
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
// Process search results.
var docs []*ai.Document
for _, result := range results {
// Find text and metadata columns in search results.
var textCol, metaCol entity.Column
for _, col := range result.Fields {
if col.Name() == textField {
textCol = col
}
if col.Name() == metadataField {
metaCol = col
}
}
// Ensure text column exists.
if textCol == nil {
return nil, fmt.Errorf("text column %s not found in search results", textField)
}
// Iterate over rows (assuming columns have same length).
for i := 0; i < result.ResultCount; i++ {
// Get text value.
text, err := textCol.GetAsString(i)
if err != nil {
fmt.Printf("Failed to parse text at index %d: %v\n", i, err)
continue
}
// Get metadata value (optional, as metadata column may be missing).
var metadata map[string]interface{}
if metaCol != nil {
metaStr, err := metaCol.GetAsString(i)
if err == nil && metaStr != "" {
if err := json.Unmarshal([]byte(metaStr), &metadata); err != nil {
fmt.Printf("Failed to parse metadata at index %d: %v\n", i, err)
continue
}
} else if err != nil {
fmt.Printf("Failed to get metadata string at index %d: %v\n", i, err)
}
}
// Ensure metadata includes user_id and username from query
if metadata == nil {
metadata = make(map[string]interface{})
}
metadata["user_id"] = userId
metadata["username"] = userName
// Create document.
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
}
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment