Commit ba8c3fb3 authored by Wade's avatar Wade

chat milvus grap param

parent 5908072a
......@@ -5,12 +5,9 @@ import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge"
......@@ -241,10 +238,7 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
// Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *ChatInput) (Response, error) {
ctxAsJson, _ := json.Marshal(ctx)
log.Info().Msgf("input----ctxAsJson----%s", string(ctxAsJson))
inputAsJson, err := json.Marshal(input)
if err != nil {
return Response{
Code: 500,
......@@ -290,42 +284,41 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Summary = *lastQa.Summary
}
metaData := make(map[string]any)
metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData)
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil {
log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error())
} else {
var sb strings.Builder
for _, d := range response.Documents {
sb.WriteString(d.Content[0].Text)
sb.WriteByte('\n')
if input.Milvus {
metaData := make(map[string]any)
metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData)
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil {
log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error())
} else {
var sb strings.Builder
for _, d := range response.Documents {
sb.WriteString(d.Content[0].Text)
sb.WriteByte('\n')
}
promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
}
promptInput.Context = sb.String()
log.Info().Msgf("promptInput.Context: %s", promptInput.Context)
}
begin := time.Now()
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest))
if err != nil {
log.Error().Msgf("graph Retrieve err.Error() %s", err.Error())
} else {
var sb strings.Builder
for _, d := range graphResponse.Documents {
sb.WriteString(d.Content[0].Text)
sb.WriteByte('\n')
if input.Graph {
graphResponse, err := ai.Retrieve(ctx, graphRetriever, ai.WithDocs(dRequest))
if err != nil {
log.Error().Msgf("graph Retrieve err.Error() %s", err.Error())
} else {
var sb strings.Builder
for _, d := range graphResponse.Documents {
sb.WriteString(d.Content[0].Text)
sb.WriteByte('\n')
}
promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
}
promptInput.Graph = sb.String()
log.Info().Msgf("promptInput.Graph : %s", promptInput.Graph)
}
fmt.Println("graph time", time.Since(begin).Seconds())
simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model)
if err != nil {
return Response{
......
......@@ -81,7 +81,6 @@ func main() {
ctx := context.Background()
// Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"},
......
......@@ -59,29 +59,26 @@ var (
}
)
// ListModels returns a map of media-supported models and their capabilities
func ListModels() (map[string]ai.ModelInfo, error) {
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
for _, modelName := range mediaSupportedModels {
// Normalize model name by removing version tags (e.g., "llava:13b" -> "llava")
baseName := strings.Split(modelName, ":")[0]
models[modelName] = ai.ModelInfo{
Label: "Ollama - " + baseName,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: true, // All models in mediaSupportedModels support media
Tools: false, // None of these models are in toolSupportedModels
},
Versions: []string{},
}
}
return models, nil
}
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
for _, modelName := range mediaSupportedModels {
// Normalize model name by removing version tags (e.g., "llava:13b" -> "llava")
baseName := strings.Split(modelName, ":")[0]
models[modelName] = ai.ModelInfo{
Label: "Ollama - " + baseName,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: true, // All models in mediaSupportedModels support media
Tools: false, // None of these models are in toolSupportedModels
},
Versions: []string{},
}
}
return models, nil
}
func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// o.mu.Lock()
......@@ -244,35 +241,35 @@ func (o *Ollama) Init(ctx context.Context, g *genkit.Genkit) (err error) {
}
o.initted = true
// Register all supported models
modelSet := make(map[string]struct{})
for _, modelName := range mediaSupportedModels {
modelSet[modelName] = struct{}{}
}
for _, modelName := range toolSupportedModels {
modelSet[modelName] = struct{}{}
}
for modelName := range modelSet {
// Determine model type
modelType := "chat"
if slices.Contains(mediaSupportedModels, modelName) && !slices.Contains(toolSupportedModels, modelName) {
modelType = "vision"
}
modelDef := ModelDefinition{
Name: modelName,
Type: modelType,
}
o.DefineModel(g, modelDef, nil)
log.Info().
Str("method", "Ollama.Init").
Str("model_name", modelName).
Str("model_type", modelType).
Msg("Registered model")
}
log.Info().Str("method", "Ollama.Init").Msg("Initialization successful")
// Register all supported models
modelSet := make(map[string]struct{})
for _, modelName := range mediaSupportedModels {
modelSet[modelName] = struct{}{}
}
for _, modelName := range toolSupportedModels {
modelSet[modelName] = struct{}{}
}
for modelName := range modelSet {
// Determine model type
modelType := "chat"
if slices.Contains(mediaSupportedModels, modelName) && !slices.Contains(toolSupportedModels, modelName) {
modelType = "vision"
}
modelDef := ModelDefinition{
Name: modelName,
Type: modelType,
}
o.DefineModel(g, modelDef, nil)
log.Info().
Str("method", "Ollama.Init").
Str("model_name", modelName).
Str("model_type", modelType).
Msg("Registered model")
}
log.Info().Str("method", "Ollama.Init").Msg("Initialization successful")
return nil
}
......
......@@ -11,8 +11,6 @@ import (
const UserNameKey = "username"
const UserIdKey = "user_id"
// Data returns the content type and bytes of the media part.
func Data(p *ai.Part) (string, []byte, error) {
if !p.IsMedia() && !p.IsData() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment