Commit 8c0bff12 authored by Wade's avatar Wade

add models param

parent ba8c3fb3
...@@ -257,10 +257,12 @@ paths: ...@@ -257,10 +257,12 @@ paths:
type: string type: string
description: The chat message content description: The chat message content
example: "What is the capital of UK?" example: "What is the capital of UK?"
model: models:
type: array
items:
type: string type: string
description: The model to use for the chat response description: The models to use for the chat response
example: "deepseek/deepseek-chat" example: ["deepseek/deepseek-chat", "ollama/llama3.1"]
apiKey: apiKey:
type: string type: string
description: The API key for authentication description: The API key for authentication
...@@ -275,7 +277,8 @@ paths: ...@@ -275,7 +277,8 @@ paths:
example: "user123" example: "user123"
to: to:
type: string type: string
description: The recipient of the chat message example Bob description: The recipient of the chat message
example: "Bob"
to_id: to_id:
type: string type: string
description: The unique identifier for the recipient description: The unique identifier for the recipient
......
...@@ -284,11 +284,12 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -284,11 +284,12 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
promptInput.Summary = *lastQa.Summary promptInput.Summary = *lastQa.Summary
} }
if input.Milvus {
metaData := make(map[string]any) metaData := make(map[string]any)
metaData[util.UserIdKey] = input.ToID metaData[util.UserIdKey] = input.ToID
metaData[util.UserNameKey] = input.To metaData[util.UserNameKey] = input.To
dRequest := ai.DocumentFromText(input.Content, metaData) dRequest := ai.DocumentFromText(input.Content, metaData)
if input.Milvus {
response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest)) response, err := ai.Retrieve(ctx, retriever, ai.WithDocs(dRequest))
if err != nil { if err != nil {
log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error()) log.Error().Msgf("milvus Retrieve err.Error() %s", err.Error())
...@@ -319,22 +320,64 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai. ...@@ -319,22 +320,64 @@ func DefineChatFlow(g *genkit.Genkit, retriever ai.Retriever, graphRetriever ai.
} }
} }
simpleQaPrompt, err := defineSimpleQaPrompt(g, input.Model) resp := &ai.ModelResponse{}
var lastErr error
for i, model := range input.Models {
simpleQaPrompt, err := defineSimpleQaPrompt(g, model)
if err != nil { if err != nil {
// 打印错误日志
log.Error().
Str("model", model).
Int("index", i).
Err(err).
Msg("Failed to define simple QA prompt")
// 如果是最后一个模型,返回错误
if i == len(input.Models)-1 {
return Response{ return Response{
Code: 500, Code: 500,
Msg: fmt.Sprintf("index document: %w", err), Msg: fmt.Sprintf("index document: %w", err),
}, nil }, nil
} }
// 记录错误,继续下一个模型
lastErr = err
continue
}
resp, err := simpleQaPrompt.Execute(ctx, ai.WithInput(promptInput)) respTemp, err := simpleQaPrompt.Execute(ctx, ai.WithInput(promptInput))
if err != nil { if err != nil {
// 打印错误日志
log.Error().
Str("model", model).
Int("index", i).
Err(err).
Msg("Failed to execute prompt")
// 如果是最后一个模型,返回错误
if i == len(input.Models)-1 {
return Response{ return Response{
Code: 500, Code: 500,
Msg: fmt.Sprintf("index document: %w", err), Msg: fmt.Sprintf("index document: %w", err),
}, nil }, nil
} }
// 记录错误,继续下一个模型
lastErr = err
continue
}
// 成功执行,更新 resp
resp = respTemp
break
}
// 所有模型处理完成,检查是否全部失败
if resp == nil && lastErr != nil {
// 如果没有成功的结果,返回最后一个错误
return Response{
Code: 500,
Msg: fmt.Sprintf("index document: %w", lastErr),
}, nil
}
if lastok { if lastok {
......
...@@ -30,7 +30,7 @@ type ChatInput struct { ...@@ -30,7 +30,7 @@ type ChatInput struct {
To string `json:"to,"` To string `json:"to,"`
ToID string `json:"to_id,omitempty"` ToID string `json:"to_id,omitempty"`
// //
Model string `json:"model,omitempty"` Models []string `json:"models,omitempty"`
APIKey string `json:"apiKey,omitempty"` APIKey string `json:"apiKey,omitempty"`
Milvus bool `json:"milvus,omitempty"` Milvus bool `json:"milvus,omitempty"`
Graph bool `json:"graph,omitempty"` Graph bool `json:"graph,omitempty"`
...@@ -62,6 +62,7 @@ type simpleQaPromptInput struct { ...@@ -62,6 +62,7 @@ type simpleQaPromptInput struct {
func main() { func main() {
// Define command-line flags with hardcoded values as defaults // Define command-line flags with hardcoded values as defaults
ollamaServerAddress := flag.String("ollama-server-address", "http://localhost:11434", "Ollama server address")
deepseekAPIKey := flag.String("deepseek-api-key", "sk-9f70df871a7c4b8aa566a3c7a0603706", "DeepSeek API key") deepseekAPIKey := flag.String("deepseek-api-key", "sk-9f70df871a7c4b8aa566a3c7a0603706", "DeepSeek API key")
milvusAddr := flag.String("milvus-addr", "54.92.111.204:19530", "Milvus server address") milvusAddr := flag.String("milvus-addr", "54.92.111.204:19530", "Milvus server address")
graphragAddr := flag.String("graphrag-addr", "54.92.111.204:5670", "GraphRAG server address") graphragAddr := flag.String("graphrag-addr", "54.92.111.204:5670", "GraphRAG server address")
...@@ -83,7 +84,7 @@ func main() { ...@@ -83,7 +84,7 @@ func main() {
// Initialize genkit with plugins using flag/env values // Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins( g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"}, &ollama.Ollama{ServerAddress: *ollamaServerAddress},
&deepseek.DeepSeek{APIKey: *deepseekAPIKey}, &deepseek.DeepSeek{APIKey: *deepseekAPIKey},
&milvus.Milvus{Addr: *milvusAddr}, &milvus.Milvus{Addr: *milvusAddr},
&graphrag.GraphKnowledge{Addr: *graphragAddr}, &graphrag.GraphKnowledge{Addr: *graphragAddr},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment