// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package graphrag

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
)

// Client is a minimal HTTP client for the knowledge-base server's REST API.
type Client struct {
	// BaseURL is the server's base URL, e.g. "http://54.92.111.204:5670".
	// Request paths such as "/knowledge/space/add" are appended to it
	// directly, so it is expected to have no trailing slash.
	BaseURL string
}

// SpaceRequest is the JSON request body AddSpace posts to create a
// knowledge space.
type SpaceRequest struct {
	ID         int    `json:"id"`
	Name       string `json:"name"`
	VectorType string `json:"vector_type"`
	DomainType string `json:"domain_type"`
	Desc       string `json:"desc"`
	Owner      string `json:"owner"`
	SpaceID    int    `json:"space_id"`
}

// DocumentRequest is the JSON request body AddDocument posts to add a
// document to a knowledge space.
type DocumentRequest struct {
	DocName   string                 `json:"doc_name"`
	DocID     int                    `json:"doc_id"`
	DocType   string                 `json:"doc_type"`
	DocToken  string                 `json:"doc_token"`
	Content   string                 `json:"content"`
	Source    string                 `json:"source"`
	Labels    string                 `json:"labels"`
	Questions []string               `json:"questions"` // Encoded as null when nil; callers here pass an empty slice.
	Metadata  map[string]interface{} `json:"metadata"`
}

// ChunkParameters holds the document-chunking settings sent with each
// SyncBatchRequest.
type ChunkParameters struct {
	ChunkStrategy string `json:"chunk_strategy"`
	TextSplitter  string `json:"text_splitter"`
	SplitterType  string `json:"splitter_type"`
	ChunkSize     int    `json:"chunk_size"`
	ChunkOverlap  int    `json:"chunk_overlap"`
	Separator     string `json:"separator"`
	EnableMerge   bool   `json:"enable_merge"`
}

// SyncBatchRequest is one element of the JSON array SyncBatchDocument posts
// to the sync_batch endpoint.
type SyncBatchRequest struct {
	DocID           int             `json:"doc_id"`
	SpaceID         string          `json:"space_id"` // NOTE(review): a string here, while SpaceRequest.SpaceID is an int — confirm against the server API.
	ModelName       string          `json:"model_name"`
	ChunkParameters ChunkParameters `json:"chunk_parameters"`
}

// NewClient returns a Client whose BaseURL is "http://<ip>:<port>".
//
// ip may be a hostname, an IPv4 literal, or an IPv6 literal with or without
// surrounding brackets. An unbracketed IPv6 literal is wrapped in brackets,
// because a bare colon-separated address inside a URL authority would be
// misread as host:port (e.g. NewClient("::1", 5670) yields
// "http://[::1]:5670" rather than the invalid "http://::1:5670").
func NewClient(ip string, port int) *Client {
	host := ip
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return &Client{
		BaseURL: fmt.Sprintf("http://%s:%d", host, port),
	}
}

// AddSpace creates a knowledge space by POSTing req as JSON to
// {BaseURL}/knowledge/space/add.
//
// Errors are returned only for marshalling, request-construction, and
// transport failures; the HTTP status is not inspected. The caller owns the
// returned response and must close its Body.
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := c.BaseURL + "/knowledge/space/add"
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "application/json")

	var httpClient http.Client
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	return resp, nil
}

// AddDocument adds a document to the knowledge space spaceID by POSTing req
// as JSON to {BaseURL}/knowledge/{spaceID}/document/add.
//
// spaceID is interpolated into the path verbatim (it is not escaped). A
// non-2xx status is not treated as an error here; the caller owns the
// returned response and must close its Body.
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
	endpoint := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)

	encoded, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	request, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Accept", "application/json")

	resp, err := (&http.Client{}).Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	return resp, nil
}

// SyncBatchDocument POSTs req, encoded as a JSON array, to the
// {BaseURL}/knowledge/{spaceID}/document/sync_batch endpoint.
//
// spaceID is interpolated into the path verbatim (it is not escaped). The
// HTTP status is not checked here; the caller owns the returned response and
// must close its Body.
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	target := c.BaseURL + "/knowledge/" + spaceID + "/document/sync_batch"
	httpReq, err := http.NewRequest(http.MethodPost, target, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	// NewRequest starts with an empty header map, so replacing it wholesale
	// is equivalent to setting these two (already canonical) keys.
	httpReq.Header = http.Header{
		"Accept":       {"application/json"},
		"Content-Type": {"application/json"},
	}

	resp, sendErr := new(http.Client).Do(httpReq)
	if sendErr != nil {
		return nil, fmt.Errorf("failed to send request: %w", sendErr)
	}
	return resp, nil
}

// provider is the plugin name ("graphrag") and the provider under which this
// package's indexers and retrievers are defined and looked up in genkit.
const provider = "graphrag"

// Field names for schema.
//
// NOTE(review): idField, textField and metadataField are not referenced by
// the code visible in this file; confirm whether they are still needed.
const (
	idField       = "id"
	textField     = "text"
	metadataField = "metadata"
)

// GraphKnowledge is the graphrag genkit plugin. It holds the configuration
// and the HTTP client used to talk to the knowledge server.
type GraphKnowledge struct {
	// Addr is the knowledge server address as "host:port"
	// (e.g. "54.92.111.204:5670"). Init substitutes a built-in default when
	// it is empty.
	Addr string

	client  *Client    // HTTP client created by Init; nil until Init succeeds.
	mu      sync.Mutex // Held by Init while it reads and updates initted and client.
	initted bool       // Whether Init has already completed successfully.
}

// Name returns the plugin name, "graphrag" (the provider constant).
func (k *GraphKnowledge) Name() string {
	return provider
}

// Init initializes the GraphKnowledge plugin.
//
// It creates the HTTP client for k.Addr, falling back to
// "54.92.111.204:5670" when Addr is empty, and marks the plugin as
// initialized. It returns an error when called on a nil *GraphKnowledge or
// when the plugin has already been initialized. ctx and g are unused.
//
// NOTE(review): the fallback is a hard-coded public IP address; consider
// requiring Addr to be set explicitly.
func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
	// A nil receiver cannot be initialized in place. Assigning a fresh value
	// to k would only change this method's local copy: the registered plugin
	// would stay uninitialized, and later calls on it (e.g. newDocStore
	// reading k.client) would dereference nil. Report the misuse instead.
	if k == nil {
		return errors.New("graphrag.Init: plugin is nil")
	}

	k.mu.Lock()
	defer k.mu.Unlock()
	defer func() {
		if err != nil {
			err = fmt.Errorf("graphrag.Init: %w", err)
		}
	}()

	if k.initted {
		return errors.New("plugin already initialized")
	}

	// Load configuration.
	addr := k.Addr
	if addr == "" {
		addr = "54.92.111.204:5670" // Default address.
	}

	// Initialize Knowledge client.
	host, port := parseAddr(addr)
	k.client = NewClient(host, port)
	k.initted = true
	return nil
}

// parseAddr splits an address of the form "host:port" into its host and
// numeric port.
//
// An address that does not contain exactly one ':' yields the fallback
// ("54.92.111.204", 5670). The port text is converted with strconv.Atoi and
// its error is ignored, so non-numeric text yields port 0.
//
// NOTE(review): malformed input silently redirects to a remote fallback
// host; consider surfacing an error to the caller instead.
func parseAddr(addr string) (string, int) {
	if strings.Count(addr, ":") != 1 {
		return "54.92.111.204", 5670
	}
	colon := strings.Index(addr, ":")
	// Deliberately ignore the Atoi error: the result is used as-is, and a
	// bad port surfaces later as a failed connection.
	port, _ := strconv.Atoi(addr[colon+1:])
	return addr[:colon], port
}

// DefineIndexerAndRetriever defines an Indexer and a Retriever backed by the
// registered graphrag plugin.
//
// It is equivalent to DefineIndexerAndRetrieverForSpace(ctx, g, "", ""). The
// actions are registered with an empty name and address an empty space ID,
// which yields request paths such as "/knowledge//document/add". Callers
// that need a real knowledge space should use
// DefineIndexerAndRetrieverForSpace.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
	return DefineIndexerAndRetrieverForSpace(ctx, g, "", "")
}

// DefineIndexerAndRetrieverForSpace defines an Indexer and a Retriever for
// the knowledge space spaceID. Both are registered under the graphrag
// provider with the name spaceID. modelName is sent as model_name in the
// server requests the indexer and retriever make.
//
// It returns an error in three cases:
//   - the graphrag plugin is not registered in g;
//   - the plugin is not a non-nil *GraphKnowledge;
//   - the plugin has not been initialized.
func DefineIndexerAndRetrieverForSpace(ctx context.Context, g *genkit.Genkit, spaceID, modelName string) (ai.Indexer, ai.Retriever, error) {
	p := genkit.LookupPlugin(g, provider)
	if p == nil {
		return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
	}
	// Use the two-value assertion so an unexpected plugin type becomes an
	// error instead of a panic.
	knowledge, ok := p.(*GraphKnowledge)
	if !ok {
		return nil, nil, fmt.Errorf("graphrag plugin has type %T, want *GraphKnowledge", p)
	}
	if knowledge == nil {
		return nil, nil, errors.New("graphrag plugin is nil")
	}

	ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
	if err != nil {
		return nil, nil, err
	}

	indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
	retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
	return indexer, retriever, nil
}

// docStore provides the Indexer (Index) and Retriever (Retrieve)
// implementations for a single knowledge space.
type docStore struct {
	client    *Client // HTTP client shared with the owning GraphKnowledge.
	spaceID   string  // Knowledge space name used in request paths and bodies.
	modelName string  // Sent as model_name in sync_batch and chat requests.
}

// newDocStore returns a docStore that uses k's HTTP client for the knowledge
// space spaceID and the model modelName.
//
// It fails when k has no client yet, i.e. Init has not completed
// successfully. ctx is currently unused.
func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
	shared := k.client
	if shared != nil {
		store := &docStore{client: shared, spaceID: spaceID, modelName: modelName}
		return store, nil
	}
	return nil, errors.New("graphrag.Init not called")
}

// Indexer returns the indexer defined under the graphrag provider with the
// name spaceID, as found by genkit.LookupIndexer.
//
// NOTE(review): presumably nil when no such indexer was defined — confirm
// against genkit.LookupIndexer.
func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
	return genkit.LookupIndexer(g, provider, spaceID)
}

// Retriever returns the retriever defined under the graphrag provider with
// the name spaceID, as found by genkit.LookupRetriever.
//
// NOTE(review): presumably nil when no such retriever was defined — confirm
// against genkit.LookupRetriever.
func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
	return genkit.LookupRetriever(g, provider, spaceID)
}

// Index implements the Indexer.Index method.
//
// It first creates the knowledge space ds.spaceID. Then, for each document
// in order, it:
//   - uploads the document's concatenated text parts as doc_name "doc_<n>"
//     and doc_id <n> (1-based);
//   - issues a sync_batch request for that document.
//
// Processing stops at the first failure, and an empty request is a no-op.
// Caller-supplied documents are not modified.
//
// NOTE(review): the space is created on every call and doc IDs restart at 1
// each call. Confirm that the server tolerates an existing space, and that
// sync_batch should reference these client-chosen IDs rather than IDs
// returned by the add-document endpoint.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
	if len(req.Documents) == 0 {
		return nil
	}

	if err := ds.createSpace(); err != nil {
		return err
	}

	for i, doc := range req.Documents {
		// The client methods take no context, so cancellation is only
		// observed between documents.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := ds.indexDocument(i+1, doc); err != nil {
			return err
		}
	}
	return nil
}

// createSpace asks the server to create the knowledge space named ds.spaceID.
func (ds *docStore) createSpace() error {
	resp, err := ds.client.AddSpace(SpaceRequest{
		ID:         1,
		Name:       ds.spaceID,
		VectorType: "hnsw",
		DomainType: "Normal",
		Desc:       "Default knowledge space",
		Owner:      "admin",
		SpaceID:    1,
	})
	if err != nil {
		return fmt.Errorf("add space: %w", err)
	}
	return checkIndexResponse(resp, "add space")
}

// indexDocument uploads doc as document number n and triggers its sync_batch
// processing. Each response body is closed before the function returns.
func (ds *docStore) indexDocument(n int, doc *ai.Document) error {
	// Copy the metadata so the caller's document is left untouched, then
	// fill in user_id/username when absent.
	metadata := make(map[string]interface{}, len(doc.Metadata)+2)
	for key, value := range doc.Metadata {
		metadata[key] = value
	}
	if _, ok := metadata["user_id"]; !ok {
		metadata["user_id"] = "user123" // Mock data.
	}
	if _, ok := metadata["username"]; !ok {
		metadata["username"] = "Alice" // Mock data.
	}

	// Only text parts are uploaded; other part kinds are skipped.
	var sb strings.Builder
	for _, p := range doc.Content {
		if p.IsText() {
			sb.WriteString(p.Text)
		}
	}
	text := sb.String()

	docReq := DocumentRequest{
		DocName:   fmt.Sprintf("doc_%d", n),
		DocID:     n,
		DocType:   "text",
		DocToken:  "",
		Content:   text,
		Source:    "api",
		Labels:    "",
		Questions: []string{},
		Metadata:  metadata,
	}
	resp, err := ds.client.AddDocument(ds.spaceID, docReq)
	if err != nil {
		return fmt.Errorf("add document %d: %w", n, err)
	}
	if err := checkIndexResponse(resp, fmt.Sprintf("add document %d", n)); err != nil {
		return err
	}

	// Sync document for embedding.
	syncReq := []SyncBatchRequest{
		{
			DocID:     docReq.DocID,
			SpaceID:   ds.spaceID,
			ModelName: ds.modelName,
			ChunkParameters: ChunkParameters{
				ChunkStrategy: "sentence",
				TextSplitter:  "recursive",
				SplitterType:  "user_define",
				ChunkSize:     512,
				ChunkOverlap:  50,
				Separator:     "\n",
				EnableMerge:   true,
			},
		},
	}
	syncResp, err := ds.client.SyncBatchDocument(ds.spaceID, syncReq)
	if err != nil {
		return fmt.Errorf("sync batch document %d: %w", n, err)
	}
	// Check the sync response's own status (not the add-document response).
	return checkIndexResponse(syncResp, fmt.Sprintf("sync batch document %d", n))
}

// checkIndexResponse closes resp.Body and returns an error carrying the
// status code and body text when the status is not 200 OK. The op string
// prefixes the error message.
func checkIndexResponse(resp *http.Response, op string) error {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // Best effort: the body only enriches the error.
		return fmt.Errorf("%s failed with status %d: %s", op, resp.StatusCode, string(body))
	}
	// Drain the body so the transport can reuse the connection.
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}

// RetrieverOptions are the options for Knowledge retrieval.
//
// NOTE(review): Retrieve does not currently read req.Options, so these
// fields have no effect yet.
type RetrieverOptions struct {
	Count      int    `json:"count,omitempty"`       // Max documents to retrieve.
	MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}

// Retrieve implements the Retriever.Retrieve method.
//
// It posts the query's text, prefixed with "Search for: ", to the server's
// /api/v1/chat/completions endpoint with ext_info.space_id set to
// ds.spaceID. It then converts each element of the response's data.answer
// array into an ai.Document, filling in defaults for missing "user_id" and
// "username" metadata keys. ctx governs the HTTP request. A non-200 status
// or an undecodable body is returned as an error.
//
// NOTE(review): req.Options (RetrieverOptions) is not applied yet; the
// request body has no field carrying a result count or similarity metric.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
	// Build the query from the query document's text parts. Formatting
	// req.Query.Content (a slice of parts) directly with %s does not yield
	// the plain query text.
	var qb strings.Builder
	for _, p := range req.Query.Content {
		if p.IsText() {
			qb.WriteString(p.Text)
		}
	}
	queryText := "Search for: " + qb.String()

	// Default user name, overridden by the query's "username" metadata when
	// present (indexing a nil map is safe and yields no match).
	username := "Alice"
	if uname, ok := req.Query.Metadata["username"].(string); ok {
		username = uname
	}

	// Prepare request for chat completions endpoint.
	url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL)
	chatReq := struct {
		ConvUID      string                 `json:"conv_uid"`
		UserInput    string                 `json:"user_input"`
		UserName     string                 `json:"user_name"`
		ChatMode     string                 `json:"chat_mode"`
		AppCode      string                 `json:"app_code"`
		Temperature  float32                `json:"temperature"`
		MaxNewTokens int                    `json:"max_new_tokens"`
		SelectParam  string                 `json:"select_param"`
		ModelName    string                 `json:"model_name"`
		Incremental  bool                   `json:"incremental"`
		SysCode      string                 `json:"sys_code"`
		PromptCode   string                 `json:"prompt_code"`
		ExtInfo      map[string]interface{} `json:"ext_info"`
	}{
		ConvUID:      "",
		UserInput:    queryText,
		UserName:     username,
		ChatMode:     "",
		AppCode:      "",
		Temperature:  0.5,
		MaxNewTokens: 4000,
		SelectParam:  "",
		ModelName:    ds.modelName,
		Incremental:  false,
		SysCode:      "",
		PromptCode:   "",
		ExtInfo: map[string]interface{}{
			"space_id": ds.spaceID,
		},
	}

	body, err := json.Marshal(chatReq)
	if err != nil {
		return nil, fmt.Errorf("marshal chat request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
	if err != nil {
		return nil, fmt.Errorf("create chat request: %w", err)
	}
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("send chat request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // Best effort: the body only enriches the error.
		return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse response.
	var chatResp struct {
		Success bool `json:"success"`
		Data    struct {
			Answer []struct {
				Content  string                 `json:"content"`
				DocID    int                    `json:"doc_id"`
				Score    float64                `json:"score"`
				Metadata map[string]interface{} `json:"metadata_map"`
			} `json:"answer"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, fmt.Errorf("decode chat response: %w", err)
	}

	var docs []*ai.Document
	for _, doc := range chatResp.Data.Answer {
		metadata := doc.Metadata
		if metadata == nil {
			metadata = make(map[string]interface{})
		}
		// Ensure metadata includes user_id and username.
		if _, ok := metadata["user_id"]; !ok {
			metadata["user_id"] = "user123"
		}
		if _, ok := metadata["username"]; !ok {
			metadata["username"] = username
		}
		docs = append(docs, ai.DocumentFromText(doc.Content, metadata))
	}

	return &ai.RetrieverResponse{
		Documents: docs,
	}, nil
}
