Commit d7a0aa40 authored by Wade

add plugins/milvus/milvus_test.go

parent 854f7d9a
......@@ -17,6 +17,7 @@ require (
github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
......@@ -44,7 +45,9 @@ require (
github.com/ollama/ollama v0.6.5 // indirect
github.com/pgvector/pgvector-go v0.3.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/test-go/testify v1.1.4 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
......
......@@ -253,6 +253,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE=
github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
......
......@@ -164,6 +164,7 @@ func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg Collec
if err != nil {
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
return indexer, retriever, nil
......@@ -175,7 +176,8 @@ type docStore struct {
collection string
dimension int
embedder ai.Embedder
embedderOptions any
// embedderOptions any
embedderOptions map[string]interface{}
}
// newDocStore creates a docStore (inspired by examples/simple/main.go).
......@@ -265,7 +267,8 @@ func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection)
}
// Index implements the Indexer.Index method (inspired by examples/insert/main.go).
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 {
return nil
......@@ -273,18 +276,16 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// Embed documents.
ereq := &ai.EmbedRequest{
Input: req.Documents,
Options: ds.embedderOptions,
Input: req.Documents,
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("milvus index embedding failed: %v", err)
return fmt.Errorf("milvus index embedding failed: %w", err)
}
// Prepare columns.
var vectors []entity.Vector
var texts []string
var metadatas []map[string]interface{}
// Prepare row-based data.
var rows []interface{}
for i, emb := range eres.Embeddings {
doc := req.Documents[i]
var sb strings.Builder
......@@ -299,20 +300,20 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
metadata = make(map[string]interface{})
}
vectors = append(vectors, entity.FloatVector(emb.Embedding))
texts = append(texts, text)
metadatas = append(metadatas, metadata)
// Create row as map[string]interface{}
row := make(map[string]interface{})
row["vector"] = entity.FloatVector(emb.Embedding) // Vector field
row["text"] = text // Text field
for k, v := range metadata {
row[k] = v // Add metadata fields dynamically
}
rows = append(rows, row)
}
_, err = ds.client.Insert(ctx, ds.collection, "",
entity.NewColumnFloatVector(vectorField, ds.dimension, vectors),
entity.NewColumnVarChar(textField, texts),
entity.NewColumnJSON(metadataField, metadatas),
)
// Insert rows into Milvus
_, err = ds.client.InsertRows(ctx, ds.collection, "", rows)
if err != nil {
return fmt.Errorf("milvus insert failed: %v", err)
return fmt.Errorf("milvus insert rows failed: %w", err)
}
return nil
......
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package milvus
import (
	"context"
	"fmt"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/milvus-io/milvus-sdk-go/v2/client"
	"github.com/test-go/testify/assert"
)
// MockEmbedder is a stateless test double used in place of a real ai.Embedder.
type MockEmbedder struct{}

// Name reports the fixed identifier of the mock embedder.
func (*MockEmbedder) Name() string {
	const name = "mock-embedder"
	return name
}
// Embed returns one embedding per element of req.Input. Every embedding is a
// freshly allocated 768-dimensional vector whose components are all 1.0, so
// the result depends only on the number of inputs, never on their content.
// It never returns an error.
func (m *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
	const dim = 768

	var out ai.EmbedResponse
	for range req.Input {
		vec := make([]float32, dim)
		for j := 0; j < dim; j++ {
			vec[j] = 1.0
		}
		out.Embeddings = append(out.Embeddings, &ai.Embedding{Embedding: vec})
	}
	return &out, nil
}
// dropCollection cleans up a test collection.
//
// It asks Milvus whether collectionName exists and drops it if so. It returns
// nil when the collection is absent or was dropped, and a wrapped error when
// either the existence check or the drop fails.
//
// The client parameter is named c so that it does not shadow the imported
// client package inside the function body.
func dropCollection(ctx context.Context, c client.Client, collectionName string) error {
	exists, err := c.HasCollection(ctx, collectionName)
	if err != nil {
		return fmt.Errorf("check collection: %w", err)
	}
	if !exists {
		// Nothing to clean up.
		return nil
	}
	if err := c.DropCollection(ctx, collectionName); err != nil {
		return fmt.Errorf("drop collection: %w", err)
	}
	return nil
}
// TestMilvusIntegration runs the Milvus indexer and retriever against a live
// Milvus server.
//
// The server address is read from the MILVUS_TEST_ADDR environment variable
// (host:port of the Milvus gRPC endpoint). When the variable is unset the test
// is skipped, so the suite does not depend on a hard-coded remote host.
//
// Each run uses a uniquely named collection, which is dropped on exit.
func TestMilvusIntegration(t *testing.T) {
	addr := os.Getenv("MILVUS_TEST_ADDR")
	if addr == "" {
		t.Skip("MILVUS_TEST_ADDR not set; skipping Milvus integration test")
	}

	ctx := context.Background()

	// Initialize Genkit.
	// NOTE(review): this uses a zero-value Genkit; confirm that the plugin does
	// not require an instance built by the genkit package's initializer.
	g := &genkit.Genkit{}

	// Initialize Milvus plugin.
	m := &Milvus{
		Addr: addr, // Milvus gRPC endpoint
	}
	if err := m.Init(ctx, g); err != nil {
		t.Fatalf("Milvus.Init failed: %v", err)
	}
	defer m.client.Close()

	// Generate unique collection name.
	collectionName := fmt.Sprintf("test_collection_%d", time.Now().UnixNano())

	// Clean up the collection after the test. This is registered before the
	// indexer/retriever are defined so that a failure there still triggers
	// cleanup; dropCollection does nothing when the collection does not exist.
	// Deferred calls run last-in-first-out, so the drop happens before the
	// client is closed.
	defer func() {
		if err := dropCollection(ctx, m.client, collectionName); err != nil {
			t.Errorf("Cleanup failed: %v", err)
		}
	}()

	// Configure collection.
	cfg := CollectionConfig{
		Collection:      collectionName,
		Dimension:       768, // Match mock embedder dimension
		Embedder:        &MockEmbedder{},
		EmbedderOptions: map[string]interface{}{},
	}

	// Define indexer and retriever.
	indexer, retriever, err := DefineIndexerAndRetriever(ctx, g, cfg)
	if err != nil {
		t.Fatalf("DefineIndexerAndRetriever failed: %v", err)
	}

	t.Run("Index and Retrieve", func(t *testing.T) {
		// Index documents.
		documents := []*ai.Document{
			{
				Content:  []*ai.Part{ai.NewTextPart("Hello world")},
				Metadata: map[string]interface{}{"id": int64(1), "category": "greeting"},
			},
			{
				Content:  []*ai.Part{ai.NewTextPart("AI is amazing")},
				Metadata: map[string]interface{}{"id": int64(2), "category": "tech"},
			},
		}
		req := &ai.IndexerRequest{Documents: documents}
		if err := indexer.Index(ctx, req); err != nil {
			t.Fatalf("Index failed: %v", err)
		}

		// Wait briefly to ensure Milvus processes the index.
		time.Sleep(1 * time.Second)

		// Retrieve documents.
		queryReq := &ai.RetrieverRequest{
			Query: &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
			Options: &RetrieverOptions{
				Count:      2,
				MetricType: "L2",
			},
		}
		resp, err := retriever.Retrieve(ctx, queryReq)
		if err != nil {
			t.Fatalf("Retrieve failed: %v", err)
		}
		// Stop here on a nil response: the checks below dereference resp.
		if resp == nil {
			t.Fatal("Response should not be nil")
		}

		// Verify results.
		assert.NotEmpty(t, resp.Documents, "Should return at least one document")
		for _, doc := range resp.Documents {
			// Guard the index below so a malformed document fails the test
			// instead of panicking.
			if doc == nil || len(doc.Content) == 0 {
				t.Errorf("Retrieved document has no content parts")
				continue
			}
			text := doc.Content[0].Text
			assert.NotEmpty(t, text, "Document text should not be empty")
			// Note: Mock embedder returns identical vectors, so results may not be exact.
			if strings.Contains(text, "Hello world") || strings.Contains(text, "AI is amazing") {
				continue
			}
			t.Errorf("Unexpected document text: %s", text)
		}
	})

	t.Run("Empty Index", func(t *testing.T) {
		req := &ai.IndexerRequest{Documents: []*ai.Document{}}
		err := indexer.Index(ctx, req)
		assert.NoError(t, err, "Indexing empty documents should succeed")
	})

	t.Run("Invalid Retrieve Options", func(t *testing.T) {
		queryReq := &ai.RetrieverRequest{
			Query:   &ai.Document{Content: []*ai.Part{ai.NewTextPart("Hello world")}},
			Options: &RetrieverOptions{MetricType: "INVALID"},
		}
		_, err := retriever.Retrieve(ctx, queryReq)
		// Fail fast on a nil error: calling err.Error() below would panic.
		if err == nil {
			t.Fatal("Should fail with invalid metric type")
		}
		assert.Contains(t, err.Error(), "unsupported metric type")
	})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment