// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

// Package milvus provides a Genkit plugin for the Milvus vector database, built on milvus-sdk-go.
package milvus

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"strings"
	"sync"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/milvus-io/milvus-sdk-go/v2/client"
	"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// provider is the name under which this plugin, and every indexer and
// retriever it defines, is registered with Genkit.
const provider string = "milvus"

// Names of the fields in the collection schema that this plugin creates and
// queries.
const (
	idField       = "id"       // Int64 primary key, generated by Milvus (AutoID).
	vectorField   = "vector"   // Float vector holding the document embedding.
	textField     = "text"     // VarChar holding the document text.
	metadataField = "metadata" // JSON holding the document metadata.
)

// Milvus is the Genkit plugin for a Milvus server. Its exported fields hold
// the connection settings; during Init, any field left empty is replaced by
// the matching MILVUS_* environment variable.
type Milvus struct {
	// Addr is the server address as host:port, e.g. "localhost:19530".
	// Falls back to $MILVUS_ADDRESS.
	Addr string
	// Username for username/password authentication.
	// Falls back to $MILVUS_USERNAME.
	Username string
	// Password for username/password authentication.
	// Falls back to $MILVUS_PASSWORD.
	Password string
	// Token is an alternative to Username/Password; Init hands it to the
	// SDK client as its API key. Falls back to $MILVUS_TOKEN.
	Token string

	client  client.Client // SDK client created by Init.
	mu      sync.Mutex    // Serializes Init.
	initted bool          // Set once Init has succeeded.
}

// Name reports the identifier the plugin is registered under; it is always
// the package-level provider constant.
func (*Milvus) Name() string { return provider }

// Init creates the Milvus client used by the plugin.
//
// Connection settings come from the exported fields. Each empty field falls
// back to an environment variable: MILVUS_ADDRESS, MILVUS_USERNAME,
// MILVUS_PASSWORD or MILVUS_TOKEN. An address is mandatory. The token is
// passed to the SDK as its API key.
//
// Init may succeed only once per plugin value; later calls return an error.
// Every returned error is prefixed with "milvus.Init:".
func (m *Milvus) Init(ctx context.Context, g *genkit.Genkit) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("milvus.Init: %w", err)
		}
	}()

	if m == nil {
		// A nil receiver has nowhere to keep the client. Connecting anyway
		// would open a connection nobody can reach or close.
		return errors.New("nil *Milvus plugin")
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.initted {
		return errors.New("plugin already initialized")
	}

	addr := valueOrEnv(m.Addr, "MILVUS_ADDRESS")
	if addr == "" {
		return errors.New("milvus address required")
	}

	cfg := client.Config{
		Address:  addr,
		Username: valueOrEnv(m.Username, "MILVUS_USERNAME"),
		Password: valueOrEnv(m.Password, "MILVUS_PASSWORD"),
		APIKey:   valueOrEnv(m.Token, "MILVUS_TOKEN"),
	}
	// Named c (not client) so the client package is not shadowed.
	c, err := client.NewClient(ctx, cfg)
	if err != nil {
		return fmt.Errorf("failed to initialize Milvus client: %w", err)
	}

	m.client = c
	m.initted = true
	return nil
}

// valueOrEnv returns v when it is non-empty, otherwise the value of the
// environment variable key (which may itself be empty).
func valueOrEnv(v, key string) string {
	if v != "" {
		return v
	}
	return os.Getenv(key)
}

// CollectionConfig describes one Milvus collection and how its documents are
// embedded. DefineIndexerAndRetriever turns it into an indexer/retriever pair.
type CollectionConfig struct {
	// Collection is the Milvus collection name; it must be non-empty.
	Collection string
	// Dimension is the length of the embedding vectors (for example 1536 for
	// text-embedding-ada-002); it must be positive.
	Dimension int
	// Embedder turns documents and queries into vectors.
	Embedder ai.Embedder
	// EmbedderOptions is forwarded as-is as the embed request options.
	EmbedderOptions any
}

// DefineIndexerAndRetriever registers an indexer and a retriever, both named
// cfg.Collection, backed by the given Milvus collection.
//
// The collection is created if it does not exist, and is then loaded. The
// milvus plugin must already be registered (via genkit.Init) and initialized.
// cfg must name a non-empty collection, a positive dimension and a non-nil
// embedder.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit, cfg CollectionConfig) (ai.Indexer, ai.Retriever, error) {
	if cfg.Embedder == nil {
		return nil, nil, errors.New("milvus: Embedder required")
	}
	if cfg.Collection == "" {
		return nil, nil, errors.New("milvus: collection name required")
	}
	if cfg.Dimension <= 0 {
		return nil, nil, errors.New("milvus: dimension must be positive")
	}

	p := genkit.LookupPlugin(g, provider)
	if p == nil {
		return nil, nil, errors.New("milvus plugin not found; did you call genkit.Init with the milvus plugin?")
	}
	// Checked assertion: a plugin of another type registered under the same
	// name, or a typed-nil *Milvus, must not panic here.
	mv, ok := p.(*Milvus)
	if !ok || mv == nil {
		return nil, nil, fmt.Errorf("milvus: plugin %q has type %T, want non-nil %T", provider, p, (*Milvus)(nil))
	}

	ds, err := mv.newDocStore(ctx, &cfg)
	if err != nil {
		return nil, nil, err
	}

	indexer := genkit.DefineIndexer(g, provider, cfg.Collection, ds.Index)
	retriever := genkit.DefineRetriever(g, provider, cfg.Collection, ds.Retrieve)
	return indexer, retriever, nil
}

// docStore is the per-collection state behind the Index and Retrieve
// functions that DefineIndexerAndRetriever registers.
type docStore struct {
	client     client.Client // Milvus client shared with the plugin.
	collection string        // Milvus collection name.
	dimension  int           // Configured embedding vector length.
	embedder   ai.Embedder   // Embeds both indexed documents and queries.
	// embedderOptions is passed through as ai.EmbedRequest.Options. It must
	// be typed any so that CollectionConfig.EmbedderOptions (also any) can be
	// assigned to it; a map type here does not compile.
	embedderOptions any
}

// newDocStore returns a docStore for cfg.Collection.
//
// If the collection does not exist it is created by createCollection. The
// collection is then loaded so it can be searched. An existing collection is
// used as-is; its schema is not checked against cfg.
//
// It fails if Init has not created a client yet.
func (m *Milvus) newDocStore(ctx context.Context, cfg *CollectionConfig) (*docStore, error) {
	if m.client == nil {
		return nil, errors.New("milvus.Init not called")
	}

	exists, err := m.client.HasCollection(ctx, cfg.Collection)
	if err != nil {
		return nil, fmt.Errorf("failed to check collection %q: %w", cfg.Collection, err)
	}
	if !exists {
		if err := m.createCollection(ctx, cfg); err != nil {
			return nil, err
		}
	}

	// Synchronous load (async=false) so the collection is searchable on return.
	if err := m.client.LoadCollection(ctx, cfg.Collection, false); err != nil {
		return nil, fmt.Errorf("failed to load collection %q: %w", cfg.Collection, err)
	}

	return &docStore{
		client:          m.client,
		collection:      cfg.Collection,
		dimension:       cfg.Dimension,
		embedder:        cfg.Embedder,
		embedderOptions: cfg.EmbedderOptions,
	}, nil
}

// createCollection creates cfg.Collection with the plugin's fixed schema and
// an HNSW (L2) index on the vector field.
//
// The schema is:
//
//	id       Int64, primary key, AutoID
//	vector   FloatVector, dim = cfg.Dimension
//	text     VarChar, max_length = maxTextLength
//	metadata JSON
func (m *Milvus) createCollection(ctx context.Context, cfg *CollectionConfig) error {
	// Milvus rejects VarChar fields without max_length. 65535 is the largest
	// value it allows.
	const maxTextLength = 65535

	// Build the index parameters before touching the server, so a failure
	// here leaves nothing behind.
	index, err := entity.NewIndexHNSW(
		entity.L2, // distance metric
		8,         // M
		96,        // efConstruction
	)
	if err != nil {
		return fmt.Errorf("entity.NewIndexHNSW: %w", err)
	}

	schema := &entity.Schema{
		CollectionName: cfg.Collection,
		Fields: []*entity.Field{
			{
				Name:       idField,
				DataType:   entity.FieldTypeInt64,
				PrimaryKey: true,
				AutoID:     true,
			},
			{
				Name:     vectorField,
				DataType: entity.FieldTypeFloatVector,
				TypeParams: map[string]string{
					"dim": fmt.Sprintf("%d", cfg.Dimension),
				},
			},
			{
				Name:     textField,
				DataType: entity.FieldTypeVarChar,
				TypeParams: map[string]string{
					"max_length": fmt.Sprintf("%d", maxTextLength),
				},
			},
			{
				Name:     metadataField,
				DataType: entity.FieldTypeJSON,
			},
		},
	}

	if err := m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil {
		return fmt.Errorf("failed to create collection %q: %w", cfg.Collection, err)
	}

	if err := m.client.CreateIndex(ctx, cfg.Collection, vectorField, index, false); err != nil {
		// Milvus will not load a collection whose vector field has no index.
		// A later call would also see the collection as existing and skip
		// index creation. Drop it so a retry starts from scratch.
		if dropErr := m.client.DropCollection(ctx, cfg.Collection); dropErr != nil {
			return fmt.Errorf("failed to create index for collection %q: %w (dropping the collection also failed: %v)", cfg.Collection, err, dropErr)
		}
		return fmt.Errorf("failed to create index for collection %q: %w", cfg.Collection, err)
	}
	return nil
}

// Indexer looks up the indexer that DefineIndexerAndRetriever registered for
// collection. When none is registered, the result is whatever
// genkit.LookupIndexer returns (presumably nil; confirm).
func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
	idx := genkit.LookupIndexer(g, provider, collection)
	return idx
}

// Retriever looks up the retriever that DefineIndexerAndRetriever registered
// for collection. When none is registered, the result is whatever
// genkit.LookupRetriever returns (presumably nil; confirm).
func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
	r := genkit.LookupRetriever(g, provider, collection)
	return r
}


// Index embeds req.Documents and inserts one row per document into the
// collection. Each row holds:
//
//   - the concatenated text parts, in the text field;
//   - the embedding, in the vector field;
//   - the document metadata encoded as JSON (an empty object when absent),
//     in the metadata field.
//
// The id field is generated by Milvus. An empty (or nil) request is a no-op.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
	if req == nil || len(req.Documents) == 0 {
		return nil
	}

	eres, err := ds.embedder.Embed(ctx, &ai.EmbedRequest{
		Input:   req.Documents,
		Options: ds.embedderOptions,
	})
	if err != nil {
		return fmt.Errorf("milvus index embedding failed: %w", err)
	}
	// Embeddings are matched to documents by position. A length mismatch
	// would index out of range or silently drop documents.
	if len(eres.Embeddings) != len(req.Documents) {
		return fmt.Errorf("milvus index: embedder returned %d embeddings for %d documents",
			len(eres.Embeddings), len(req.Documents))
	}

	rows := make([]interface{}, 0, len(req.Documents))
	for i, doc := range req.Documents {
		var sb strings.Builder
		for _, p := range doc.Content {
			if p.IsText() {
				sb.WriteString(p.Text)
			}
		}

		// The schema has no dynamic fields, only a required JSON metadata
		// field. Metadata keys spread into the row would be ignored, and a
		// row without the metadata field is rejected.
		metadata := doc.Metadata
		if metadata == nil {
			metadata = make(map[string]interface{})
		}
		metaJSON, err := json.Marshal(metadata)
		if err != nil {
			return fmt.Errorf("milvus index: encoding metadata of document %d: %w", i, err)
		}

		rows = append(rows, map[string]interface{}{
			// Stored as a plain []float32. The SDK's row-to-column conversion
			// type-asserts float vectors to []float32, so a named
			// entity.FloatVector is not used here.
			vectorField:   eres.Embeddings[i].Embedding,
			textField:     sb.String(),
			metadataField: metaJSON,
		})
	}

	// An empty partition name targets the collection's default partition.
	if _, err := ds.client.InsertRows(ctx, ds.collection, "", rows); err != nil {
		return fmt.Errorf("milvus insert rows failed: %w", err)
	}
	return nil
}

// RetrieverOptions tunes a Milvus search. Pass a *RetrieverOptions as the
// retriever request options.
type RetrieverOptions struct {
	// Count caps the number of returned documents; values <= 0 select the
	// default of 3.
	Count int `json:"count,omitempty"`
	// MetricType names the similarity metric: "L2" (the default) or "IP".
	MetricType string `json:"metric_type,omitempty"`
}


// Retrieve embeds req.Query and runs a top-K HNSW vector search over the
// collection. Each hit is returned as a text document carrying the metadata
// stored in the metadata field.
//
// req.Options may be nil, a *RetrieverOptions or a RetrieverOptions. Count
// defaults to 3 and MetricType to "L2"; only "L2" and "IP" are accepted.
//
// NOTE(review): the index created by newDocStore uses L2. Milvus generally
// requires the search metric to match the index metric, so "IP" is presumably
// only useful on collections indexed with IP elsewhere. Confirm.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
	count := 3
	metricTypeStr := "L2"

	var ropt *RetrieverOptions
	switch o := req.Options.(type) {
	case nil:
	case *RetrieverOptions:
		ropt = o // may be a typed nil; handled below
	case RetrieverOptions:
		ropt = &o
	default:
		return nil, fmt.Errorf("milvus.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
	}
	if ropt != nil {
		if ropt.Count > 0 {
			count = ropt.Count
		}
		if ropt.MetricType != "" {
			metricTypeStr = ropt.MetricType
		}
	}

	var metricType entity.MetricType
	switch metricTypeStr {
	case "L2":
		metricType = entity.L2
	case "IP":
		metricType = entity.IP
	default:
		return nil, fmt.Errorf("unsupported metric type: %s", metricTypeStr)
	}

	eres, err := ds.embedder.Embed(ctx, &ai.EmbedRequest{
		Input:   []*ai.Document{req.Query},
		Options: ds.embedderOptions,
	})
	if err != nil {
		return nil, fmt.Errorf("milvus retrieve embedding failed: %w", err)
	}
	if len(eres.Embeddings) == 0 {
		return nil, errors.New("no embeddings generated for query")
	}
	queryVector := entity.FloatVector(eres.Embeddings[0].Embedding)

	// HNSW search requires ef >= topK. Raise ef above its default when more
	// results are requested; a fixed ef would reject any Count above it.
	const defaultEF = 64
	ef := defaultEF
	if count > ef {
		ef = count
	}
	searchParams, err := entity.NewIndexHNSWSearchParam(ef)
	if err != nil {
		return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %w", err)
	}

	results, err := ds.client.Search(
		ctx,
		ds.collection,
		[]string{},                         // partitions: empty means all
		"",                                 // expr: empty means no filtering
		[]string{textField, metadataField}, // output fields
		[]entity.Vector{queryVector},       // query vectors
		vectorField,                        // vector field to search
		metricType,                         // metric type
		count,                              // topK
		searchParams,                       // search parameters
	)
	if err != nil {
		return nil, fmt.Errorf("milvus search failed: %w", err)
	}

	var docs []*ai.Document
	for _, result := range results {
		if result.Err != nil {
			return nil, fmt.Errorf("milvus search failed: %w", result.Err)
		}
		if result.ResultCount == 0 {
			continue
		}
		textCol := result.Fields.GetColumn(textField)
		if textCol == nil {
			return nil, fmt.Errorf("milvus search result is missing field %q", textField)
		}
		metadataCol := result.Fields.GetColumn(metadataField)

		for i := 0; i < result.ResultCount; i++ {
			text, err := textCol.GetAsString(i)
			if err != nil {
				return nil, fmt.Errorf("milvus: reading %q of hit %d: %w", textField, i, err)
			}
			metadata, err := decodeMetadata(metadataCol, i)
			if err != nil {
				return nil, fmt.Errorf("milvus: reading %q of hit %d: %w", metadataField, i, err)
			}
			docs = append(docs, ai.DocumentFromText(text, metadata))
		}
	}

	return &ai.RetrieverResponse{Documents: docs}, nil
}

// decodeMetadata decodes row i of a JSON metadata column into a map.
//
// It returns nil metadata without error in these cases:
//   - the column is absent;
//   - the value is empty or JSON null;
//   - the value has an unexpected Go type (metadata is best-effort).
func decodeMetadata(col entity.Column, i int) (map[string]any, error) {
	if col == nil {
		return nil, nil
	}
	v, err := col.Get(i)
	if err != nil {
		return nil, err
	}
	var raw []byte
	switch b := v.(type) {
	case []byte:
		raw = b
	case string:
		raw = []byte(b)
	default:
		return nil, nil
	}
	if len(raw) == 0 {
		return nil, nil
	}
	var md map[string]any
	if err := json.Unmarshal(raw, &md); err != nil {
		return nil, err
	}
	return md, nil
}
