Commit ba8c3fb3 authored by Wade's avatar Wade

chat milvus grap param

parent 5908072a
...@@ -5,12 +5,9 @@ import ( ...@@ -5,12 +5,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag" "github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge" "github.com/wade-liwei/agentchat/plugins/knowledge"
...@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) {
ctxAsJson, _ := json.Marshal(ctx)
log.Info().Msgf("input----ctxAsJson----%s", string(ctxAsJson))
inputAsJson, err := json.Marshal(input) inputAsJson, err := json.Marshal(input)
if err != nil { if err != nil {
return Response{ return Response{
Code: 500, Code: 500,
...@@ -290,42 +284,41 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -290,42 +284,41 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Summary = *lastQa.Summary promptInput.Summary = *lastQa.Summary
} }
metaData := make(map[string]any) if input.Milvus {
metaData[util.UserIdKey] = input.ToID metaData := make(map[string]any)
metaData[util.UserNameKey] = input.To metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData) dRequest := ai.DocumentFromText(input.Content, metaData)
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest)) response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil { if err != nil {
log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error()) log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error())
} else { } else {
var sb strings.Builder var sb strings.Builder
for _, d := range response.Documents { for _, d := range response.Documents {
sb.WriteString(d.Content[0].Text) sb.WriteString(d.Content[0].Text)
sb.WriteByte('\n') sb.WriteByte('\n')
}
promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
} }
promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
} }
begin := time.Now() if input.Graph {
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest))
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest)) if err != nil {
if err != nil { log.Error().Msgf("graph Retrieve err.Error() %s", err.Error())
log.Error().Msgf("graph Retrieve err.Error() %s", err.Error()) } else {
} else { var sb strings.Builder
var sb strings.Builder for _, d := range graphResponse.Documents {
for _, d := range graphResponse.Documents { sb.WriteString(d.Content[0].Text)
sb.WriteString(d.Content[0].Text) sb.WriteByte('\n')
sb.WriteByte('\n') }
promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
} }
promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
} }
fmt.Println("graph time", time.Since(begin).Seconds())
simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model) simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model)
if err != nil { if err != nil {
return Response{ return Response{
......
...@@ -81,7 +81,6 @@ func main() { ...@@ -81,7 +81,6 @@ func main() {
ctx := context.Background() ctx := context.Background()
// Initialize genkit with plugins using flag/env values // Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins( g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"}, &ollama.Ollama{ServerAddress: "http://localhost:11434"},
......
...@@ -59,29 +59,26 @@ var ( ...@@ -59,29 +59,26 @@ var (
} }
) )
// ListModels returns a map of media-supported models and their capabilities // ListModels returns a map of media-supported models and their capabilities
func ListModels() (map[string]ai.ModelInfo, error) { func ListModels() (map[string]ai.ModelInfo, error) {
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels)) models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
for _, modelName := range mediaSupportedModels { for _, modelName := range mediaSupportedModels {
// Normalize model name by removing version tags (e.g., "llava:13b" -> "llava") // Normalize model name by removing version tags (e.g., "llava:13b" -> "llava")
baseName := strings.Split(modelName, ":")[0] baseName := strings.Split(modelName, ":")[0]
models[modelName] = ai.ModelInfo{ models[modelName] = ai.ModelInfo{
Label: "Ollama - " + baseName, Label: "Ollama - " + baseName,
Supports: &ai.ModelSupports{ Supports: &ai.ModelSupports{
Multiturn: true, Multiturn: true,
SystemRole: true, SystemRole: true,
Media: true, // All models in mediaSupportedModels support media Media: true, // All models in mediaSupportedModels support media
Tools: false, // None of these models are in toolSupportedModels Tools: false, // None of these models are in toolSupportedModels
}, },
Versions: []string{}, Versions: []string{},
} }
} }
return models, nil
}
return models, nil
}
func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// o.mu.Lock() // o.mu.Lock()
...@@ -244,35 +241,35 @@ func (o *Ollama) Init(ctx context.Context, g *genkit.Genkit) (err error) { ...@@ -244,35 +241,35 @@ func (o *Ollama) Init(ctx context.Context, g *genkit.Genkit) (err error) {
} }
o.initted = true o.initted = true
// Register all supported models // Register all supported models
modelSet := make(map[string]struct{}) modelSet := make(map[string]struct{})
for _, modelName := range mediaSupportedModels { for _, modelName := range mediaSupportedModels {
modelSet[modelName] = struct{}{} modelSet[modelName] = struct{}{}
} }
for _, modelName := range toolSupportedModels { for _, modelName := range toolSupportedModels {
modelSet[modelName] = struct{}{} modelSet[modelName] = struct{}{}
} }
for modelName := range modelSet { for modelName := range modelSet {
// Determine model type // Determine model type
modelType := "chat" modelType := "chat"
if slices.Contains(mediaSupportedModels, modelName) && !slices.Contains(toolSupportedModels, modelName) { if slices.Contains(mediaSupportedModels, modelName) && !slices.Contains(toolSupportedModels, modelName) {
modelType = "vision" modelType = "vision"
} }
modelDef := ModelDefinition{ modelDef := ModelDefinition{
Name: modelName, Name: modelName,
Type: modelType, Type: modelType,
} }
o.DefineModel(g, modelDef, nil) o.DefineModel(g, modelDef, nil)
log.Info(). log.Info().
Str("method", "Ollama.Init"). Str("method", "Ollama.Init").
Str("model_name", modelName). Str("model_name", modelName).
Str("model_type", modelType). Str("model_type", modelType).
Msg("Registered model") Msg("Registered model")
} }
log.Info().Str("method", "Ollama.Init").Msg("Initialization successful") log.Info().Str("method", "Ollama.Init").Msg("Initialization successful")
return nil return nil
} }
......
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
const UserNameKey = "username" const UserNameKey = "username"
const UserIdKey = "user_id" const UserIdKey = "user_id"
// Data returns the content type and bytes of the media part. // Data returns the content type and bytes of the media part.
func Data(p *ai.Part) (string, []byte, error) { func Data(p *ai.Part) (string, []byte, error) {
if !p.IsMedia() && !p.IsData() { if !p.IsMedia() && !p.IsData() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment