Commit ba8c3fb3 authored by Wade's avatar Wade

chat milvus grap param

parent 5908072a
...@@ -5,12 +5,9 @@ import ( ...@@ -5,12 +5,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag" "github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge" "github.com/wade-liwei/agentchat/plugins/knowledge"
...@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) {
ctxAsJson, _ := json.Marshal(ctx)
log.Info().Msgf("input----ctxAsJson----%s", string(ctxAsJson))
inputAsJson, err := json.Marshal(input) inputAsJson, err := json.Marshal(input)
if err != nil { if err != nil {
return Response{ return Response{
Code: 500, Code: 500,
...@@ -290,10 +284,10 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -290,10 +284,10 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Summary = *lastQa.Summary promptInput.Summary = *lastQa.Summary
} }
if input.Milvus {
metaData := make(map[string]any) metaData := make(map[string]any)
metaData[util.UserIdKey] = input.ToID metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData) dRequest := ai.DocumentFromText(input.Content, metaData)
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest)) response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil { if err != nil {
...@@ -308,9 +302,9 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -308,9 +302,9 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Context = sb.String() promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context) log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
} }
}
begin := time.Now() if input.Graph {
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest)) graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest))
if err != nil { if err != nil {
log.Error().Msgf("graph Retrieve err.Error() %s", err.Error()) log.Error().Msgf("graph Retrieve err.Error() %s", err.Error())
...@@ -323,8 +317,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -323,8 +317,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Graph = sb.String() promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph) log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
} }
}
fmt.Println("graph time", time.Since(begin).Seconds())
simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model) simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model)
if err != nil { if err != nil {
......
...@@ -81,7 +81,6 @@ func main() { ...@@ -81,7 +81,6 @@ func main() {
ctx := context.Background() ctx := context.Background()
// Initialize genkit with plugins using flag/env values // Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins( g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"}, &ollama.Ollama{ServerAddress: "http://localhost:11434"},
......
...@@ -59,7 +59,6 @@ var ( ...@@ -59,7 +59,6 @@ var (
} }
) )
// ListModels returns a map of media-supported models and their capabilities // ListModels returns a map of media-supported models and their capabilities
func ListModels() (map[string]ai.ModelInfo, error) { func ListModels() (map[string]ai.ModelInfo, error) {
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels)) models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
...@@ -81,8 +80,6 @@ func ListModels() (map[string]ai.ModelInfo, error) { ...@@ -81,8 +80,6 @@ func ListModels() (map[string]ai.ModelInfo, error) {
return models, nil return models, nil
} }
func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// o.mu.Lock() // o.mu.Lock()
// defer o.mu.Unlock() // defer o.mu.Unlock()
......
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
const UserNameKey = "username" const UserNameKey = "username"
const UserIdKey = "user_id" const UserIdKey = "user_id"
// Data returns the content type and bytes of the media part. // Data returns the content type and bytes of the media part.
func Data(p *ai.Part) (string, []byte, error) { func Data(p *ai.Part) (string, []byte, error) {
if !p.IsMedia() && !p.IsData() { if !p.IsMedia() && !p.IsData() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment