Commit c8150637 authored by Wade's avatar Wade

milvus write ok

parent d7a0aa40
......@@ -107,7 +107,7 @@ func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
token = os.Getenv("MILVUS_TOKEN")
}
// Initialize Milvus client (inspired by examples/simple/main.go).
// Initialize Milvus client.
config := client.Config{
Address: addr,
Username: username,
......@@ -119,12 +119,6 @@ func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
return fmt.Errorf("failed to initialize Milvus client: %v", err)
}
// Verify connection.
// if err := client.Connect(ctx); err != nil {
// return fmt.Errorf("failed to connect to Milvus: %v", err)
// }
m.client = client
m.initted = true
return nil
......@@ -164,7 +158,7 @@ func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg Collec
if err != nil {
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
return indexer, retriever, nil
......@@ -176,11 +170,10 @@ type docStore struct {
collection string
dimension int
embedder ai.Embedder
// embedderOptions any
embedderOptions map[string]interface{}
}
// newDocStore creates a docStore (inspired by examples/simple/main.go).
// newDocStore creates a docStore.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
if m.client == nil {
return nil, errors.New("milvus.Init not called")
......@@ -210,8 +203,11 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
},
},
{
Name: textField,
DataType: entity.FieldTypeVarChar,
Name: textField,
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "65535",
},
},
{
Name: metadataField,
......@@ -225,35 +221,46 @@ func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docSt
return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
}
// Create HNSW index (per examples/index/main.go).
index,err :=entity.NewIndexHNSW(
entity.L2, // Distance metric.
8, // M
96, // efConstruction
)
if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW %v", err)
// Create HNSW index.
index, err := entity.NewIndexHNSW(
entity.L2,
8, // M
96, // efConstruction
)
if err != nil {
return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
}
err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
if err != nil {
return nil, fmt.Errorf("failed to create index for collection %q: %v", cfg.Collection, err)
return nil, fmt.Errorf("failed to create index: %v", err)
}
}
// Load collection (per examples/simple/main.go).
// Load collection.
err = m.client.LoadCollection(ctx, cfg.Collection, false)
if err != nil {
return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
}
// Convert EmbedderOptions to map[string]interface{}.
var embedderOptions map[string]interface{}
if cfg.EmbedderOptions != nil {
opts, ok := cfg.EmbedderOptions.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
}
embedderOptions = opts
} else {
embedderOptions = make(map[string]interface{})
}
return &docStore{
client: m.client,
collection: cfg.Collection,
dimension: cfg.Dimension,
embedder: cfg.Embedder,
embedderOptions: cfg.EmbedderOptions,
embedderOptions: embedderOptions,
}, nil
}
......@@ -267,7 +274,6 @@ func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection)
}
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 {
......@@ -284,6 +290,11 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
return fmt.Errorf("milvus index embedding failed: %w", err)
}
// Validate embedding count matches document count.
if len(eres.Embeddings) != len(req.Documents) {
return fmt.Errorf("mismatch: got %d embeddings for %d documents", len(eres.Embeddings), len(req.Documents))
}
// Prepare row-based data.
var rows []interface{}
for i, emb := range eres.Embeddings {
......@@ -300,17 +311,21 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
metadata = make(map[string]interface{})
}
// Create row as map[string]interface{}
// Create row with explicit metadata field.
row := make(map[string]interface{})
row["vector"] = entity.FloatVector(emb.Embedding) // Vector field
row["text"] = text // Text field
for k, v := range metadata {
row[k] = v // Add metadata fields dynamically
}
row["vector"] = emb.Embedding // []float32
row["text"] = text
row["metadata"] = metadata // Explicitly set metadata as JSON-compatible map
rows = append(rows, row)
// Debug: Log row contents.
fmt.Printf("Row %d: vector_len=%d, text=%q, metadata=%v\n", i, len(emb.Embedding), text, metadata)
}
// Insert rows into Milvus
// Debug: Log total rows.
fmt.Printf("Inserting %d rows into collection %q\n", len(rows), ds.collection)
// Insert rows into Milvus.
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil {
return fmt.Errorf("milvus insert rows failed: %w", err)
......@@ -325,8 +340,7 @@ type RetrieverOptions struct {
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// Retrieve implements the Retriever.Retrieve method (inspired by examples/search/main.go).
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
count := 3 // Default.
metricTypeStr := "L2"
......@@ -356,8 +370,8 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Embed query.
ereq := &ai.EmbedRequest{
Input: []*ai.Document{req.Query}, // Fixed: Use req.Document instead of req.Query.Content
Options: ds.embedderOptions,
Input: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
......@@ -378,15 +392,15 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
results, err := ds.client.Search(
ctx,
ds.collection,
[]string{}, // partitions: empty for all partitions
"", // expr: empty for no filtering
[]string{}, // partitions
"", // expr
[]string{textField, metadataField}, // output fields
[]entity.Vector{queryVector}, // vectors
vectorField, // vector field
metricType, // metric type
count, // topK
searchParams, // search parameters
) // opts: omitted
[]entity.Vector{queryVector},
vectorField,
metricType,
count,
searchParams,
)
if err != nil {
return nil, fmt.Errorf("milvus search failed: %v", err)
}
......@@ -396,22 +410,11 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
for _, result := range results {
for i := 0; i < result.ResultCount; i++ {
textCol := result.Fields.GetColumn(textField)
text, err := textCol.GetAsString(i)
if err != nil {
continue
}
//metadataCol := result.Fields.GetColumn(metadataField)
// metadataBytes, err := metadataCol.Get(i)
// if err != nil {
// continue
// }
var metadata map[string]interface{}
// if len(metadataBytes) > 0 {
// if err := json.Unmarshal(metadataBytes.([]byte), &metadata); err != nil {
// continue
// }
// }
doc := ai.DocumentFromText(text, metadata)
docs = append(docs, doc)
}
......@@ -421,3 +424,463 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Documents: docs,
}, nil
}
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// // Package milvus provides a Genkit plugin for Milvus vector database using milvus-sdk-go.
// package milvus
// import (
// "context"
// "errors"
// "fmt"
// "os"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/milvus-io/milvus-sdk-go/v2/client"
// "github.com/milvus-io/milvus-sdk-go/v2/entity"
// )
// // The provider used in the registry.
// const provider = "milvus"
// // Field names for Milvus schema.
// const (
// idField = "id"
// vectorField = "vector"
// textField = "text"
// metadataField = "metadata"
// )
// // Milvus holds configuration for the plugin.
// type Milvus struct {
// // Milvus server address (host:port, e.g., "localhost:19530").
// // Defaults to MILVUS_ADDRESS environment variable.
// Addr string
// // Username for authentication.
// // Defaults to MILVUS_USERNAME.
// Username string
// // Password for authentication.
// // Defaults to MILVUS_PASSWORD.
// Password string
// // Token for authentication (alternative to username/password).
// // Defaults to MILVUS_TOKEN.
// Token string
// client client.Client // Milvus client.
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the plugin name.
// func (m *Milvus) Name() string {
// return provider
// }
// // Init initializes the Milvus plugin.
// func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
// if m == nil {
// m = &Milvus{}
// }
// m.mu.Lock()
// defer m.mu.Unlock()
// defer func() {
// if err != nil {
// err = fmt.Errorf("milvus.Init: %w", err)
// }
// }()
// if m.initted {
// return errors.New("plugin already initialized")
// }
// // Load configuration.
// addr := m.Addr
// if addr == "" {
// addr = os.Getenv("MILVUS_ADDRESS")
// }
// if addr == "" {
// return errors.New("milvus address required")
// }
// username := m.Username
// if username == "" {
// username = os.Getenv("MILVUS_USERNAME")
// }
// password := m.Password
// if password == "" {
// password = os.Getenv("MILVUS_PASSWORD")
// }
// token := m.Token
// if token == "" {
// token = os.Getenv("MILVUS_TOKEN")
// }
// // Initialize Milvus client.
// config := client.Config{
// Address: addr,
// Username: username,
// Password: password,
// APIKey: token,
// }
// client, err := client.NewClient(ctx, config)
// if err != nil {
// return fmt.Errorf("failed to initialize Milvus client: %v", err)
// }
// m.client = client
// m.initted = true
// return nil
// }
// // CollectionConfig holds configuration for an indexer/retriever pair.
// type CollectionConfig struct {
// // Milvus collection name. Must not be empty.
// Collection string
// // Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
// Dimension int
// // Embedder for generating vectors.
// Embedder ai.Embedder
// // Embedder options.
// EmbedderOptions any
// }
// // DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
// func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
// if cfg.Embedder == nil {
// return nil, nil, errors.New("milvus: Embedder required")
// }
// if cfg.Collection == "" {
// return nil, nil, errors.New("milvus: collection name required")
// }
// if cfg.Dimension <= 0 {
// return nil, nil, errors.New("milvus: dimension must be positive")
// }
// m := genkit.LookupPlugin(g, provider)
// if m == nil {
// return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
// }
// milvus := m.(*Milvus)
// ds, err := milvus.newDocStore(ctx, &cfg)
// if err != nil {
// return nil, nil, err
// }
// indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
// retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
// return indexer, retriever, nil
// }
// // docStore defines an Indexer and a Retriever.
// type docStore struct {
// client client.Client
// collection string
// dimension int
// embedder ai.Embedder
// embedderOptions map[string]interface{}
// }
// // newDocStore creates a docStore.
// func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
// if m.client == nil {
// return nil, errors.New("milvus.Init not called")
// }
// // Check/create collection.
// exists, err := m.client.HasCollection(ctx, cfg.Collection)
// if err != nil {
// return nil, fmt.Errorf("failed to check collection %q: %v", cfg.Collection, err)
// }
// if !exists {
// // Define schema.
// schema := &entity.Schema{
// CollectionName: cfg.Collection,
// Fields: []*entity.Field{
// {
// Name: idField,
// DataType: entity.FieldTypeInt64,
// PrimaryKey: true,
// AutoID: true,
// },
// {
// Name: vectorField,
// DataType: entity.FieldTypeFloatVector,
// TypeParams: map[string]string{
// "dim": fmt.Sprintf("%d", cfg.Dimension),
// },
// },
// {
// Name: textField,
// DataType: entity.FieldTypeVarChar,
// TypeParams: map[string]string{
// "max_length": "65535", // Set max_length for VarChar
// },
// },
// {
// Name: metadataField,
// DataType: entity.FieldTypeJSON,
// },
// },
// }
// err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
// if err != nil {
// return nil, fmt.Errorf("failed to create collection %q: %v", cfg.Collection, err)
// }
// // Create HNSW index.
// index, err := entity.NewIndexHNSW(
// entity.L2, // Distance metric.
// 8, // M
// 96, // efConstruction
// )
// if err != nil {
// return nil, fmt.Errorf("entity.NewIndexHNSW: %v", err)
// }
// err = m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false)
// if err != nil {
// return nil, fmt.Errorf("failed to create index for collection %q: %v", cfg.Collection, err)
// }
// }
// // Load collection.
// err = m.client.LoadCollection(ctx, cfg.Collection, false)
// if err != nil {
// return nil, fmt.Errorf("failed to load collection %q: %v", cfg.Collection, err)
// }
// // Convert EmbedderOptions to map[string]interface{}.
// var embedderOptions map[string]interface{}
// if cfg.EmbedderOptions != nil {
// opts, ok := cfg.EmbedderOptions.(map[string]interface{})
// if !ok {
// return nil, fmt.Errorf("EmbedderOptions must be a map[string]interface{}, got %T", cfg.EmbedderOptions)
// }
// embedderOptions = opts
// } else {
// embedderOptions = make(map[string]interface{})
// }
// return &docStore{
// client: m.client,
// collection: cfg.Collection,
// dimension: cfg.Dimension,
// embedder: cfg.Embedder,
// embedderOptions: embedderOptions,
// }, nil
// }
// // Indexer returns the indexer for a collection.
// func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
// return genkit.LookupIndexer(g, provider, collection)
// }
// // Retriever returns the retriever for a collection.
// func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
// return genkit.LookupRetriever(g, provider, collection)
// }
// // Index implements the Indexer.Index method.
// func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// if len(req.Documents) == 0 {
// return nil
// }
// // Embed documents.
// ereq := &ai.EmbedRequest{
// Input: req.Documents,
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return fmt.Errorf("milvus index embedding failed: %w", err)
// }
// // Prepare row-based data.
// var rows []interface{}
// for i, emb := range eres.Embeddings {
// doc := req.Documents[i]
// var sb strings.Builder
// for _, p := range doc.Content {
// if p.IsText() {
// sb.WriteString(p.Text)
// }
// }
// text := sb.String()
// metadata := doc.Metadata
// if metadata == nil {
// metadata = make(map[string]interface{})
// }
// // Create row as map[string]interface{}
// row := make(map[string]interface{})
// row["vector"] = emb.Embedding//entity.FloatVector(emb.Embedding) // Vector field
// row["text"] = text // Text field
// for k, v := range metadata {
// row[k] = v // Add metadata fields dynamically
// }
// rows = append(rows, row)
// }
// // Insert rows into Milvus
// _, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
// if err != nil {
// return fmt.Errorf("milvus insert rows failed: %w", err)
// }
// return nil
// }
// // RetrieverOptions for Milvus retrieval.
// type RetrieverOptions struct {
// Count int `json:"count,omitempty"` // Max documents to retrieve.
// MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
// }
// // Retrieve implements the Retriever.Retrieve method.
// func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// count := 3 // Default.
// metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// // Map string metric type to entity.MetricType.
// var metricType entity.MetricType
// switch metricTypeStr {
// case "L2":
// metricType = entity.L2
// case "IP":
// metricType = entity.IP
// default:
// return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
// }
// // Embed query.
// ereq := &ai.EmbedRequest{
// Input: []*ai.Document{req.Query},
// Options: ds.embedderOptions,
// }
// eres, err := ds.embedder.Embed(ctx, ereq)
// if err != nil {
// return nil, fmt.Errorf("milvus retrieve embedding failed: %v", err)
// }
// if len(eres.Embeddings) == 0 {
// return nil, errors.New("no embeddings generated for query")
// }
// queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)
// // Create search parameters.
// searchParams, err := entity.NewIndexHNSWSearchParam(64) // ef
// if err != nil {
// return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
// }
// // Perform search.
// results, err := ds.client.Search(
// ctx,
// ds.collection,
// []string{}, // partitions: empty for all partitions
// "", // expr: empty for no filtering
// []string{textField, metadataField}, // output fields
// []entity.Vector{queryVector}, // vectors
// vectorField, // vector field
// metricType, // metric type
// count, // topK
// searchParams, // search parameters
// )
// if err != nil {
// return nil, fmt.Errorf("milvus search failed: %v", err)
// }
// // Process results.
// var docs []*ai.Document
// for _, result := range results {
// for i := 0; i < result.ResultCount; i++ {
// textCol := result.Fields.GetColumn(textField)
// text, err := textCol.GetAsString(i)
// if err != nil {
// continue
// }
// var metadata map[string]interface{}
// doc := ai.DocumentFromText(text, metadata)
// docs = append(docs, doc)
// }
// }
// return &ai.RetrieverResponse{
// Documents: docs,
// }, nil
// }
......@@ -66,16 +66,22 @@ func dropCollection(ctx context.Context, client client.Client, collectionName st
func TestMilvusIntegration(t *testing.T) {
ctx := context.Background()
// Initialize Genkit
g := &genkit.Genkit{}
// Initialize Milvus plugin
m := &Milvus{
ms := Milvus{
Addr: "54.92.111.204:19530", // Milvus gRPC endpoint
}
err := m.Init(ctx, g)
// Initialize Genkit with Milvus plugin
g, err := genkit.Init(ctx, genkit.WithPlugins(&ms))
if err != nil {
t.Fatalf("Milvus.Init failed: %v", err)
t.Fatalf("genkit.Init failed: %v", err)
}
// Get the Milvus client for cleanup
m, ok := genkit.LookupPlugin(g, provider).(*Milvus)
if !ok {
t.Fatalf("Failed to lookup Milvus plugin")
}
defer m.client.Close()
......@@ -87,7 +93,7 @@ func TestMilvusIntegration(t *testing.T) {
Collection: collectionName,
Dimension: 768, // Match mock embedder dimension
Embedder: &MockEmbedder{},
EmbedderOptions: map[string]interface{}{},
EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
}
// Define indexer and retriever
......@@ -116,7 +122,7 @@ func TestMilvusIntegration(t *testing.T) {
},
}
req := &ai.IndexerRequest{Documents: documents}
err = indexer.Index(ctx, req)
err := indexer.Index(ctx, req)
if err != nil {
t.Fatalf("Index failed: %v", err)
}
......@@ -152,7 +158,7 @@ func TestMilvusIntegration(t *testing.T) {
t.Run("Empty Index", func(t *testing.T) {
req := &ai.IndexerRequest{Documents: []*ai.Document{}}
err = indexer.Index(ctx, req)
err := indexer.Index(ctx, req)
assert.NoError(t, err, "Indexing empty documents should succeed")
})
......@@ -161,9 +167,219 @@ func TestMilvusIntegration(t *testing.T) {
Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
Options: &RetrieverOptions{MetricType: "INVALID"},
}
_, err = retriever.Retrieve(ctx, queryReq)
_, err := retriever.Retrieve(ctx, queryReq)
assert.Error(t, err, "Should fail with invalid metric type")
assert.Contains(t, err.Error(), "unsupported metric type")
})
t.Run("Invalid Embedder Options", func(t *testing.T) {
// Test with invalid EmbedderOptions type
invalidCfg := CollectionConfig{
Collection: collectionName + "_invalid",
Dimension: 768,
Embedder: &MockEmbedder{},
EmbedderOptions: "not-a-map", // Invalid type
}
_, _, err := DefineIndexerAndRetriever(ctx, g, invalidCfg)
assert.Error(t, err, "Should fail with invalid EmbedderOptions type")
assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}")
})
}
// // Copyright 2025 Google LLC
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package milvus
// import (
// "context"
// "fmt"
// "strings"
// "testing"
// "time"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/milvus-io/milvus-sdk-go/v2/client"
// "github.com/test-go/testify/assert"
// )
// // MockEmbedder is a mock implementation of ai.Embedder.
// type MockEmbedder struct{}
// func (m *MockEmbedder) Name() string {
// return "mock-embedder"
// }
// func (m *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
// resp := &ai.EmbedResponse{}
// for range req.Input {
// // Generate a simple embedding (768-dimensional vector of ones)
// embedding := make([]float32, 768)
// for i := range embedding {
// embedding[i] = 1.0
// }
// resp.Embeddings = append(resp.Embeddings, &ai.Embedding{Embedding: embedding})
// }
// return resp, nil
// }
// // dropCollection cleans up a test collection.
// func dropCollection(ctx context.Context, client client.Client, collectionName string) error {
// exists, err := client.HasCollection(ctx, collectionName)
// if err != nil {
// return fmt.Errorf("check collection: %w", err)
// }
// if exists {
// err = client.DropCollection(ctx, collectionName)
// if err != nil {
// return fmt.Errorf("drop collection: %w", err)
// }
// }
// return nil
// }
// func TestMilvusIntegration(t *testing.T) {
// ctx := context.Background()
// // Initialize Milvus plugin
// ms := Milvus{
// Addr: "54.92.111.204:19530", // Milvus gRPC endpoint
// }
// // Initialize Genkit with Milvus plugin
// g, err := genkit.Init(ctx, genkit.WithPlugins(&ms))
// if err != nil {
// t.Fatalf("genkit.Init failed: %v", err)
// }
// // Get the Milvus client for cleanup
// m, ok := genkit.LookupPlugin(g, provider).(*Milvus)
// if !ok {
// t.Fatalf("Failed to lookup Milvus plugin")
// }
// defer m.client.Close()
// // Generate unique collection name
// collectionName := fmt.Sprintf("test_collection_%d", time.Now().UnixNano())
// // Configure collection
// cfg := CollectionConfig{
// Collection: collectionName,
// Dimension: 768, // Match mock embedder dimension
// Embedder: &MockEmbedder{},
// EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
// }
// // Define indexer and retriever
// indexer, retriever, err := DefineIndexerAndRetriever(ctx, g, cfg)
// if err != nil {
// t.Fatalf("DefineIndexerAndRetriever failed: %v", err)
// }
// // Clean up collection after test
// defer func() {
// if err := dropCollection(ctx, m.client, collectionName); err != nil {
// t.Errorf("Cleanup failed: %v", err)
// }
// }()
// t.Run("Index and Retrieve", func(t *testing.T) {
// // Index documents
// documents := []*ai.Document{
// {
// Content: []*ai.Part{ai.NewTextPart("Hello world")},
// Metadata: map[string]interface{}{"id": int64(1), "category": "greeting"},
// },
// {
// Content: []*ai.Part{ai.NewTextPart("AI is amazing")},
// Metadata: map[string]interface{}{"id": int64(2), "category": "example"},
// },
// }
// req := &ai.IndexerRequest{Documents: documents}
// err := indexer.Index(ctx, req)
// if err != nil {
// t.Fatalf("Index failed: %v", err)
// }
// // Wait briefly to ensure Milvus processes the index
// time.Sleep(1 * time.Second)
// // Retrieve documents
// queryReq := &ai.RetrieverRequest{
// Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
// Options: &RetrieverOptions{
// Count: 2,
// MetricType: "L2",
// },
// }
// resp, err := retriever.Retrieve(ctx, queryReq)
// if err != nil {
// t.Fatalf("Retrieve failed: %v", err)
// }
// // Verify results
// assert.NotNil(t, resp, "Response should not be nil")
// assert.NotEmpty(t, resp.Documents, "Should return at least one document")
// for _, doc := range resp.Documents {
// assert.NotEmpty(t, doc.Content[0].Text, "Document text should not be empty")
// // Note: Mock embedder returns identical vectors, so results may not be exact
// if strings.Contains(doc.Content[0].Text, "Hello world") || strings.Contains(doc.Content[0].Text, "AI is amazing") {
// continue
// }
// t.Errorf("Unexpected document text: %s", doc.Content[0].Text)
// }
// })
// t.Run("Empty Index", func(t *testing.T) {
// req := &ai.IndexerRequest{Documents: []*ai.Document{}}
// err := indexer.Index(ctx, req)
// assert.NoError(t, err, "Indexing empty documents should succeed")
// })
// t.Run("Invalid Retrieve Options", func(t *testing.T) {
// queryReq := &ai.RetrieverRequest{
// Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
// Options: &RetrieverOptions{MetricType: "INVALID"},
// }
// _, err := retriever.Retrieve(ctx, queryReq)
// assert.Error(t, err, "Should fail with invalid metric type")
// assert.Contains(t, err.Error(), "unsupported metric type")
// })
// t.Run("Invalid Embedder Options", func(t *testing.T) {
// // Test with invalid EmbedderOptions type
// invalidCfg := CollectionConfig{
// Collection: collectionName + "_invalid",
// Dimension: 768,
// Embedder: &MockEmbedder{},
// EmbedderOptions: "not-a-map", // Invalid type
// }
// _, _, err := DefineIndexerAndRetriever(ctx, g, invalidCfg)
// assert.Error(t, err, "Should fail with invalid EmbedderOptions type")
// assert.Contains(t, err.Error(), "EmbedderOptions must be a map[string]interface{}")
// })
// }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment