Commit ba8c3fb3 authored by Wade's avatar Wade

chat milvus grap param

parent 5908072a
......@@ -5,12 +5,9 @@ import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge"
......@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
// Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) {
ctxAsJson, _ := json.Marshal(ctx)
log.Info().Msgf("input----ctxAsJson----%s", string(ctxAsJson))
inputAsJson, err := json.Marshal(input)
if err != nil {
return Response{
Code: 500,
......@@ -290,10 +284,10 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Summary = *lastQa.Summary
}
if input.Milvus {
metaData := make(map[string]any)
metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData)
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil {
......@@ -308,9 +302,9 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
}
}
begin := time.Now()
if input.Graph {
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest))
if err != nil {
log.Error().Msgf("graph Retrieve err.Error() %s", err.Error())
......@@ -323,8 +317,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
}
fmt.Println("graph time", time.Since(begin).Seconds())
}
simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model)
if err != nil {
......
......@@ -81,7 +81,6 @@ func main() {
ctx := context.Background()
// Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"},
......
......@@ -59,7 +59,6 @@ var (
}
)
// ListModels returns a map of media-supported models and their capabilities
func ListModels() (map[string]ai.ModelInfo, error) {
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
......@@ -81,8 +80,6 @@ func ListModels() (map[string]ai.ModelInfo, error) {
return models, nil
}
func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// o.mu.Lock()
// defer o.mu.Unlock()
......
......@@ -11,8 +11,6 @@ import (
const UserNameKey = "username"
const UserIdKey = "user_id"
// Data returns the content type and bytes of the media part.
func Data(p *ai.Part) (string, []byte, error) {
if !p.IsMedia() && !p.IsData() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment