Commit 89fcd46e authored by Wade's avatar Wade

update qa

parent 1e8a8704
This diff is collapsed.
...@@ -152,7 +152,7 @@ paths: ...@@ -152,7 +152,7 @@ paths:
/chat: /chat:
post: post:
summary: Send a chat message summary: Send a chat message
description: Sends a chat message to the Genkit AI workflow and returns a response. description: Sends a chat message to the Genkit AI workflow and returns a response, with optional Milvus and GraphRAG indexing flags.
tags: tags:
- Chat - Chat
requestBody: requestBody:
...@@ -184,12 +184,19 @@ paths: ...@@ -184,12 +184,19 @@ paths:
example: "user123" example: "user123"
to: to:
type: string type: string
description: The recipient of the chat message description: The recipient of the chat message example Bob
example: "Bob"
to_id: to_id:
type: string type: string
description: The unique identifier for the recipient description: The unique identifier for the recipient
example: "user456" example: "user456"
milvus:
type: boolean
description: Whether to use Milvus indexing for the chat content
example: true
graph:
type: boolean
description: Whether to use GraphRAG indexing for the chat content
example: false
required: required:
- content - content
responses: responses:
......
package main
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/milvus"
)
func startServer(g *genkit.Genkit, db *sql.DB, indexer ai.Indexer, retriever ai.Retriever, embedder ai.Embedder, apiKey string) {
http.HandleFunc("/idx/milvus", handleIndex(indexer, embedder, "Milvus"))
http.HandleFunc("/idx/graphrag", handleIndex(indexer, embedder, "GraphRAG")) // 如果需要支持 GraphRAG
http.HandleFunc("/index", handleIndexTrigger(g, db, indexer, apiKey))
http.HandleFunc("/askQuestion", handleAskQuestion(g, retriever))
addr := fmt.Sprintf(":%s", *port)
log.Printf("Starting server on %s", addr)
if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatalf("Server failed: %v", err)
}
}
func handleIndex(indexer ai.Indexer, embedder ai.Embedder, indexType string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
return
}
var req IndexRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"Invalid request body"}`, http.StatusBadRequest)
return
}
if req.Question == "" || req.Answer == "" {
http.Error(w, `{"error":"Missing required fields: question and answer"}`, http.StatusBadRequest)
return
}
var userID *int64
if req.UserID != nil {
id, err := strconv.ParseInt(*req.UserID, 10, 64)
if err != nil {
http.Error(w, `{"error":"Invalid user_id format"}`, http.StatusBadRequest)
return
}
userID = &id
}
// 构造文本内容
text := req.Question + " " + req.Answer
if req.Summary != nil {
text += " " + *req.Summary
}
// 构造元数据
metadata := map[string]interface{}{
"username": req.Username,
"user_id": userID,
}
// 生成唯一 ID(由于 Milvus 插件的 schema idField AutoID,这里仅用于响应)
id := time.Now().UnixNano()
// 创建文档
doc := &ai.Document{
Content: []*ai.Part{ai.NewTextPart(text)},
Metadata: metadata,
}
// 使用 Indexer 写入 Milvus
err := ai.Index(r.Context(), indexer, ai.WithDocs(doc))
if err != nil {
log.Printf("Failed to index %s data: %v", indexType, err)
http.Error(w, `{"error":"Failed to store data"}`, http.StatusInternalServerError)
return
}
resp := IndexResponse{ID: id}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Failed to encode response: %v", err)
}
}
}
func handleIndexTrigger(g *genkit.Genkit, db *sql.DB, indexer ai.Indexer, expectedAPIKey string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
return
}
var req IndexTriggerRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"Invalid request body"}`, http.StatusBadRequest)
return
}
if req.APIKey != expectedAPIKey {
http.Error(w, `{"error":"Invalid API key"}`, http.StatusBadRequest)
return
}
if err := indexExistingRows(r.Context(), db, indexer); err != nil {
log.Printf("Failed to index data: %v", err)
http.Error(w, `{"error":"Failed to index data"}`, http.StatusInternalServerError)
return
}
resp := IndexTriggerResponse{Message: "Indexing completed successfully"}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Failed to encode response: %v", err)
}
}
}
func handleAskQuestion(g *genkit.Genkit, retriever ai.Retriever) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
return
}
var input struct {
Question string `json:"Question"`
Show string `json:"Show"`
}
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
http.Error(w, `{"error":"Invalid request body"}`, http.StatusBadRequest)
return
}
if input.Question == "" || input.Show == "" {
http.Error(w, `{"error":"Missing required fields: Question and Show"}`, http.StatusBadRequest)
return
}
// 创建查询文档
queryDoc := &ai.Document{
Content: []*ai.Part{ai.NewTextPart(input.Question)},
}
// 使用 Retriever 检索
retrieverOptions := &milvus.RetrieverOptions{
Count: 3, // 获取前 3 个结果
MetricType: "L2",
}
result, err := ai.Retrieve(r.Context(), retriever, ai.WithQuery(queryDoc), ai.WithOptions(retrieverOptions))
if err != nil {
log.Printf("Failed to retrieve data: %v", err)
http.Error(w, `{"error":"Failed to process question"}`, http.StatusInternalServerError)
return
}
// 构造响应(可以根据需要处理检索结果)
var responseText string
for _, doc := range result.Documents {
for _, part := range doc.Content {
if part.IsText() {
responseText += part.Text + "\n"
}
}
}
resp := struct {
Response string `json:"response"`
}{Response: responseText}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Failed to encode response: %v", err)
}
}
}
func indexExistingRows(ctx context.Context, db *sql.DB, indexer ai.Indexer) error {
rows, err := db.QueryContext(ctx, `SELECT id, question, answer, summary FROM qa`)
if err != nil {
return err
}
defer rows.Close()
var docs []*ai.Document
for rows.Next() {
var id int64
var question, answer, summary sql.NullString
if err := rows.Scan(&id, &question, &answer, &summary); err != nil {
return err
}
content := question.String
if answer.Valid {
content += " " + answer.String
}
if summary.Valid {
content += " " + summary.String
}
docs = append(docs, &ai.Document{
Content: []*ai.Part{ai.NewTextPart(content)},
Metadata: map[string]interface{}{
"id": id,
},
})
}
if err := rows.Err(); err != nil {
return err
}
return ai.Index(ctx, indexer, ai.WithDocs(docs...))
}
\ No newline at end of file
...@@ -32,12 +32,15 @@ import ( ...@@ -32,12 +32,15 @@ import (
type ChatInput struct { type ChatInput struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"`
From string `json:"from,omitempty"` // 替换 Username From string `json:"from,omitempty"` // 替换 Username
FromID string `json:"from_id,omitempty"` // 替换 UserID FromID string `json:"from_id,omitempty"` // 替换 UserID
To string `json:"to,"` To string `json:"to,"`
ToID string `json:"to_id,omitempty"` ToID string `json:"to_id,omitempty"`
//
Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Milvus bool `json:"milvus,omitempty"`
Graph bool `json:"graph,omitempty"`
} }
// DocumentInput 结构体用于文档索引接口 // DocumentInput 结构体用于文档索引接口
...@@ -283,6 +286,8 @@ func main() { ...@@ -283,6 +286,8 @@ func main() {
Question: &input.Content, Question: &input.Content,
To: &input.To, To: &input.To,
ToID: &input.ToID, ToID: &input.ToID,
Milvus: &input.Milvus,
Graph: &input.Graph,
}) })
if err != nil { if err != nil {
...@@ -358,18 +363,18 @@ func main() { ...@@ -358,18 +363,18 @@ func main() {
if lastok { if lastok {
if promptInput.Summary == ""{ if promptInput.Summary == "" {
promptInput.Summary = resp.Text() promptInput.Summary = resp.Text()
} }
log.Info(). log.Info().
Str("from",input.From). Str("from", input.From).
Str("from_id",input.FromID). Str("from_id", input.FromID).
Str("to",input.To). Str("to", input.To).
Str("to_id",input.ToID). Str("to_id", input.ToID).
Str("promptInput.Query",promptInput.Query). Str("promptInput.Query", promptInput.Query).
Str("resp.Text()",resp.Text()). Str("resp.Text()", resp.Text()).
Str("promptInput.Summary",promptInput.Summary).Msg("QueryRewriteWithSummary") Str("promptInput.Summary", promptInput.Summary).Msg("QueryRewriteWithSummary")
res, err := kc.QueryRewriteWithSummary(context.Background(), promptInput.Query, resp.Text(), promptInput.Summary) res, err := kc.QueryRewriteWithSummary(context.Background(), promptInput.Query, resp.Text(), promptInput.Summary)
...@@ -378,7 +383,7 @@ func main() { ...@@ -378,7 +383,7 @@ func main() {
} else { } else {
qa.UpdateQAFields(context.Background(), idx, res.RewrittenQuery, resp.Text()) qa.UpdateQAFields(context.Background(), idx, res.RewrittenQuery, resp.Text())
/* /*
{"RewrittenQuery":"Conversation summary: The available knowledge base does not contain information about the capital of the UK.","RawResponse":{"Response":{"Content":"Conversation summary: The available knowledge base does not contain information about the capital of the UK.","Usage":{"InputTokens":74,"OutputTokens":19,"TotalTokens":93},"RequestId":"15f1ce0c-a83f-4d95-af22-33a3bd829e8d"}}} {"RewrittenQuery":"Conversation summary: The available knowledge base does not contain information about the capital of the UK.","RawResponse":{"Response":{"Content":"Conversation summary: The available knowledge base does not contain information about the capital of the UK.","Usage":{"InputTokens":74,"OutputTokens":19,"TotalTokens":93},"RequestId":"15f1ce0c-a83f-4d95-af22-33a3bd829e8d"}}}
*/ */
} }
} else { } else {
...@@ -386,10 +391,10 @@ func main() { ...@@ -386,10 +391,10 @@ func main() {
} }
log.Info(). log.Info().
Str("from",input.From). Str("from", input.From).
Str("from_id",input.FromID). Str("from_id", input.FromID).
Str("to",input.To). Str("to", input.To).
Str("to_id",input.ToID). Str("to_id", input.ToID).
Str("question", promptInput.Query). Str("question", promptInput.Query).
Str("context", promptInput.Context). Str("context", promptInput.Context).
Str("graph", promptInput.Graph). Str("graph", promptInput.Graph).
......
...@@ -144,8 +144,8 @@ func TestKnowledgeClient_QueryRewriteWithSummary(t *testing.T) { ...@@ -144,8 +144,8 @@ func TestKnowledgeClient_QueryRewriteWithSummary(t *testing.T) {
name: "ValidWithSummary", name: "ValidWithSummary",
userQuestion: "你的家在哪里", userQuestion: "你的家在哪里",
assistantAnswer: "国内", assistantAnswer: "国内",
historySummary: "null", //"User asked about location preferences earlier.", historySummary: "null", //"User asked about location preferences earlier.",
expectError: true, // Expect error due to potentially invalid credentials expectError: true, // Expect error due to potentially invalid credentials
}, },
} }
......
...@@ -8,10 +8,13 @@ create table public.qa ( ...@@ -8,10 +8,13 @@ create table public.qa (
"to" text null, "to" text null,
from_id text null, from_id text null,
to_id text null, to_id text null,
milvus boolean null default false,
graph boolean null default false,
constraint qa_pkey primary key (id) constraint qa_pkey primary key (id)
) TABLESPACE pg_default; ) TABLESPACE pg_default;
CREATE VIEW public.qa_latest_from_id AS CREATE VIEW public.qa_latest_from_id AS
SELECT DISTINCT ON (qa.from_id) SELECT DISTINCT ON (qa.from_id)
qa.id, qa.id,
...@@ -22,7 +25,9 @@ SELECT DISTINCT ON (qa.from_id) ...@@ -22,7 +25,9 @@ SELECT DISTINCT ON (qa.from_id)
qa."from", qa."from",
qa."to", qa."to",
qa.from_id, qa.from_id,
qa.to_id qa.to_id,
qa.milvus,
qa.graph
FROM FROM
qa qa
WHERE WHERE
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment