Commit 6ea40366 authored by Wade's avatar Wade

add graph

parent ccab4bec
...@@ -8,5 +8,14 @@ curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/chat ...@@ -8,5 +8,14 @@ curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/chat
curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/indexDocuments
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "Document indexed successfully"}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/milvus" "github.com/wade-liwei/agentchat/plugins/milvus"
"github.com/firebase/genkit/go/plugins/evaluators" "github.com/firebase/genkit/go/plugins/evaluators"
...@@ -20,24 +21,24 @@ import ( ...@@ -20,24 +21,24 @@ import (
_ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档 _ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档
) )
// GraphKnowledge
type Input struct { type Input struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"` APIKey string `json:"apiKey,omitempty"`
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
} }
// DocumentInput 结构体用于文档索引接口 // DocumentInput 结构体用于文档索引接口
type DocumentInput struct { type DocumentInput struct {
Content string `json:"content"` Content string `json:"content"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func main() { func main() {
ctx := context.Background() ctx := context.Background()
ds := deepseek.DeepSeek{ ds := deepseek.DeepSeek{
...@@ -59,7 +60,11 @@ func main() { ...@@ -59,7 +60,11 @@ func main() {
}, },
} }
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds,&mil,&googlegenai.GoogleAI{APIKey:"AIzaSyCoYBOmnwRWlH_-nT25lpn8pMg3T18Q0uI"}, &evaluators.GenkitEval{Metrics: metrics})) graph := graphrag.GraphKnowledge{
Addr: "54.92.111.204:5670",
}
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds, &mil, &graph, &googlegenai.GoogleAI{APIKey: "AIzaSyCoYBOmnwRWlH_-nT25lpn8pMg3T18Q0uI"}, &evaluators.GenkitEval{Metrics: metrics}))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
...@@ -71,19 +76,16 @@ func main() { ...@@ -71,19 +76,16 @@ func main() {
}, },
nil) nil)
embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001")
if embedder == nil { if embedder == nil {
log.Fatal("embedder is not defined") log.Fatal("embedder is not defined")
} }
// Configure collection // Configure collection
cfg := milvus.CollectionConfig{ cfg := milvus.CollectionConfig{
Collection: "useridx", Collection: "useridx",
Dimension: 768, // Match mock embedder dimension Dimension: 768, // Match mock embedder dimension
Embedder: embedder, Embedder: embedder,
EmbedderOptions: map[string]interface{}{}, // Explicitly set as map EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
} }
...@@ -91,36 +93,47 @@ func main() { ...@@ -91,36 +93,47 @@ func main() {
indexer, retriever, err := milvus.DefineIndexerAndRetriever(ctx, g, cfg) indexer, retriever, err := milvus.DefineIndexerAndRetriever(ctx, g, cfg)
if err != nil { if err != nil {
log.Fatalf("DefineIndexerAndRetriever failed: %v", err) log.Fatalf("DefineIndexerAndRetriever failed: %v", err)
} }
_ = retriever
_ = retriever
// 定义文档索引流 // 定义文档索引流
genkit.DefineFlow(g, "indexDocuments", func(ctx context.Context, input *DocumentInput) (string, error) { genkit.DefineFlow(g, "indexDocuments", func(ctx context.Context, input *DocumentInput) (string, error) {
doc := ai.DocumentFromText(input.Content, input.Metadata) doc := ai.DocumentFromText(input.Content, input.Metadata)
err := indexer.Index(ctx, &ai.IndexerRequest{ err := indexer.Index(ctx, &ai.IndexerRequest{
Documents:[]*ai.Document{doc}, Documents: []*ai.Document{doc},
}) })
if err != nil { if err != nil {
return "", fmt.Errorf("index document: %w", err) return "", fmt.Errorf("index document: %w", err)
} }
return "Document indexed successfully", nil return "Document indexed successfully", nil
}) })
graphIndexer, graphRetriever, err := graphrag.DefineIndexerAndRetriever(ctx, g)
_ = graphRetriever
genkit.DefineFlow(g, "indexGraph", func(ctx context.Context, input *DocumentInput) (string, error) {
doc := ai.DocumentFromText(input.Content, input.Metadata)
err := graphIndexer.Index(ctx, &ai.IndexerRequest{
Documents: []*ai.Document{doc},
})
if err != nil {
return "", fmt.Errorf("index document: %w", err)
}
return "Document indexed successfully", nil
})
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) {
inputAsJson, err := json.Marshal(input) inputAsJson, err := json.Marshal(input)
if err != nil{ if err != nil {
return "",err return "", err
} }
fmt.Println("input-------------------------------",string(inputAsJson))
fmt.Println("input-------------------------------", string(inputAsJson))
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
...@@ -152,15 +165,13 @@ func main() { ...@@ -152,15 +165,13 @@ func main() {
mux.Handle("POST /"+a.Name(), handler) mux.Handle("POST /"+a.Name(), handler)
} }
// 暴露 Swagger UI,使用 swagger.yaml
mux.HandleFunc("/swagger/", httpSwagger.Handler(
httpSwagger.URL("/docs/swagger.yaml"), // 指定 YAML 文件路径
))
// 暴露 Swagger UI,使用 swagger.yaml // 确保 docs 目录可通过 HTTP 访问
mux.HandleFunc("/swagger/", httpSwagger.Handler( mux.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs"))))
httpSwagger.URL("/docs/swagger.yaml"), // 指定 YAML 文件路径
))
// 确保 docs 目录可通过 HTTP 访问
mux.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs"))))
// 启动服务器,监听 // 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:8000") log.Printf("Server starting on 0.0.0.0:8000")
......
This diff is collapsed.
/*
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/space/add' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"id": 0,
"name": "string",
"vector_type": "string",
"domain_type": "Normal",
"desc": "string",
"owner": "string",
"space_id": 0
}'
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/111111/document/add' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"doc_name": "string",
"doc_id": 0,
"doc_type": "string",
"doc_token": "string",
"content": "string",
"source": "string",
"labels": "string",
"questions": [
"string"
]
}'
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/1111/document/sync_batch' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '[
{
"doc_id": 0,
"space_id": "string",
"model_name": "string",
"chunk_parameters": {
"chunk_strategy": "string",
"text_splitter": "string",
"splitter_type": "user_define",
"chunk_size": 512,
"chunk_overlap": 50,
"separator": "\n",
"enable_merge": true
}
}
]'
curl -X 'POST' \
'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"conv_uid": "",
"user_input": "",
"user_name": "string",
"chat_mode": "",
"app_code": "",
"temperature": 0.5,
"max_new_tokens": 4000,
"select_param": "string",
"model_name": "string",
"incremental": false,
"sys_code": "string",
"prompt_code": "string",
"ext_info": {}
}'
*/
package knowledge
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
)
// Client 知识库客户端
type Client struct {
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
}
// SpaceRequest 创建空间的请求结构体
type SpaceRequest struct {
ID int `json:"id"`
Name string `json:"name"`
VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"`
Desc string `json:"desc"`
Owner string `json:"owner"`
SpaceID int `json:"space_id"`
}
// DocumentRequest 添加文档的请求结构体
type DocumentRequest struct {
DocName string `json:"doc_name"`
DocID int `json:"doc_id"`
DocType string `json:"doc_type"`
DocToken string `json:"doc_token"`
Content string `json:"content"`
Source string `json:"source"`
Labels string `json:"labels"`
Questions []string `json:"questions"`
}
// ChunkParameters 分片参数
type ChunkParameters struct {
ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"`
}
// SyncBatchRequest 同步批处理的请求结构体
type SyncBatchRequest struct {
DocID int `json:"doc_id"`
SpaceID string `json:"space_id"`
ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"`
}
// NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client {
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
}
// AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// SyncBatchDocument 同步批处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
SILICONFLOW_API_KEY=sk-ogigzbogipwhtkwvwnoeiovjdalkotopnpkwkxlvsvjsmyms docker compose up -d
curl 'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'Accept-Language: zh-CN,zh;q=0.9' \
-H 'Connection: keep-alive' \
-H 'Content-Type: application/json' \
-b 'sb-eowcghempsnzqalhkois-auth-token=base64-eyJhY2Nlc3NfdG9rZW4iOiJleUpoYkdjaU9pSklVekkxTmlJc0ltdHBaQ0k2SWpkemVpdEtXQzlCU0ZaSFFUWnBhVE1pTENKMGVYQWlPaUpLVjFRaWZRLmV5SnBjM01pT2lKb2RIUndjem92TDJWdmQyTm5hR1Z0Y0hOdWVuRmhiR2hyYjJsekxuTjFjR0ZpWVhObExtTnZMMkYxZEdndmRqRWlMQ0p6ZFdJaU9pSXpaamRoWVdRMU1TMHpNVGd3TFRSbVlXRXRPV0ZsTnkwMlpXRm1PV001Tm1JNFpHVWlMQ0poZFdRaU9pSmhkWFJvWlc1MGFXTmhkR1ZrSWl3aVpYaHdJam94TnpRNE16TXpOREV5TENKcFlYUWlPakUzTkRnek1qazRNVElzSW1WdFlXbHNJam9pYkdsM1pXbGZkMkZrWlVCcFkyeHZkV1F1WTI5dElpd2ljR2h2Ym1VaU9pSWlMQ0poY0hCZmJXVjBZV1JoZEdFaU9uc2ljSEp2ZG1sa1pYSWlPaUpsYldGcGJDSXNJbkJ5YjNacFpHVnljeUk2V3lKbGJXRnBiQ0pkZlN3aWRYTmxjbDl0WlhSaFpHRjBZU0k2ZXlKbGJXRnBiQ0k2SW14cGQyVnBYM2RoWkdWQWFXTnNiM1ZrTG1OdmJTSXNJbVZ0WVdsc1gzWmxjbWxtYVdWa0lqcDBjblZsTENKd2FHOXVaVjkyWlhKcFptbGxaQ0k2Wm1Gc2MyVXNJbk4xWWlJNklqTm1OMkZoWkRVeExUTXhPREF0TkdaaFlTMDVZV1UzTFRabFlXWTVZemsyWWpoa1pTSjlMQ0p5YjJ4bElqb2lZWFYwYUdWdWRHbGpZWFJsWkNJc0ltRmhiQ0k2SW1GaGJERWlMQ0poYlhJaU9sdDdJbTFsZEdodlpDSTZJbkJoYzNOM2IzSmtJaXdpZEdsdFpYTjBZVzF3SWpveE56UTJOVE0zTXpNNWZWMHNJbk5sYzNOcGIyNWZhV1FpT2lKa1lUTmhOelpqTmkweVlXVXdMVFEyWVdNdFlUaGxOUzFtWVRCallUTXpObUl5TlRZaUxDSnBjMTloYm05dWVXMXZkWE1pT21aaGJITmxmUS52MEhKdFRHN3EzSnc4QkRqdlFrWm9pOTVKcnNVZGNUZi1FWjBvc2d6OEk0IiwidG9rZW5fdHlwZSI6ImJlYXJlciIsImV4cGlyZXNfaW4iOjM2MDAsImV4cGlyZXNfYXQiOjE3NDgzMzM0MTIsInJlZnJlc2hfdG9rZW4iOiJscmpzdGl5MnM0a2giLCJ1c2VyIjp7ImlkIjoiM2Y3YWFkNTEtMzE4MC00ZmFhLTlhZTctNmVhZjljOTZiOGRlIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsInJvbGUiOiJhdXRoZW50aWNhdGVkIiwiZW1haWwiOiJsaXdlaV93YWRlQGljbG91ZC5jb20iLCJlbWFpbF9jb25maXJtZWRfYXQiOiIyMDI1LTAyLTEyVDExOjA0OjQwLjIxNzkzOVoiLCJwaG9uZSI6IiIsImNvbmZpcm1hdGlvbl9zZW50X2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC41Mzk2MjhaIiwiY29uZmlybWVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDo0MC4yMTc5MzlaIiwibGFzdF9zaWduX2luX2F0IjoiMjAyNS0wNS0wNlQxMzo0NToyMC45MDA0MTRaIiwiYXBwX21ldGFkYXRhIjp7InByb3ZpZGVyIjoiZW1haWwiLCJwcm92aWRlcnMiOlsiZW1haWwiXX0sInVzZXJfbWV0YWRhdGEiOnsiZW1haWwiOiJsaXdlaV93YWRlQGljbG91ZC5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwicGhvbmVfdmVyaWZpZWQiOmZhbHNlLCJzdWIiOiIzZjdhYWQ1MS0zMTgwLTRmYWEtOWFlNy02ZWFmOWM5NmI4ZGUifSwiaWRlbnRpdGllcyI6W3siaWRlbnRpdHlfaWQiOiI2MjFiYTUxZi0yYzYzLTQxOWMtOWI2OS0zYzUzYTc5NDlhMzkiLCJpZCI6IjNmN2FhZDUxLTMxODAtNGZhYS05YWU3LTZlYWY5Yzk2YjhkZSIsInVzZXJfaWQiOiIzZjdhYWQ1MS0zMTgwLTRmYWEtOWFlNy02ZWFmOWM5NmI4ZGUiLCJpZGVudGl0eV9kYXRhIjp7ImVtYWlsIjoibGl3ZWlfd2FkZUBpY2xvdWQuY29tIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsInBob25lX3ZlcmlmaWVkIjpmYWxzZSwic3ViIjoiM2Y3YWFkNTEtMzE4MC00ZmFhLTlhZTctNmVhZjljOTZiOGRlIn0sInByb3ZpZGVyIjoiZW1haWwiLCJsYXN0X3NpZ25faW5fYXQiOiIyMDI1LTAyLTEyVDExOjA0OjAwLjUxMTAwOVoiLCJjcmVhdGVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC41MTExNDRaIiwidXBkYXRlZF9hdCI6IjIwMjUtMDItMTJUMTE6MDQ6MDAuNTExMTQ0WiIsImVtYWlsIjoibGl3ZWlfd2FkZUBpY2xvdWQuY29tIn1dLCJjcmVhdGVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC40NDYwMzdaIiwidXBkYXRlZF9hdCI6IjIwMjUtMDUtMjdUMDc6MTA6MTIuMjM0MDUyWiIsImlzX2Fub255bW91cyI6ZmFsc2V9fQ' \
-H 'Origin: http://54.92.111.204:5670' \
-H 'Referer: http://54.92.111.204:5670/chat?scene=chat_knowledge&id=a6997f46-3f9b-11f0-b9d7-36eb2f648a81&knowledge_id=bbbbbb' \
-H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36' \
-H 'accept: text/event-stream' \
-H 'user-id;' \
--data-raw '{"chat_mode":"chat_knowledge","model_name":"Qwen/Qwen2.5-Coder-32B-Instruct","user_input":"你好","app_code":"chat_knowledge","temperature":0.6,"max_new_tokens":4000,"select_param":"bbbbbb","conv_uid":"a6997f46-3f9b-11f0-b9d7-36eb2f648a81"}'
curl -X 'POST' \
'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"conv_uid": "",
"user_input": "",
"user_name": "string",
"chat_mode": "",
"app_code": "",
"temperature": 0.5,
"max_new_tokens": 4000,
"select_param": "string",
"model_name": "string",
"incremental": false,
"sys_code": "string",
"prompt_code": "string",
"ext_info": {}
}'
...@@ -273,6 +273,7 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer { ...@@ -273,6 +273,7 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
func Retriever(g *genkit.Genkit, collection string) ai.Retriever { func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection) return genkit.LookupRetriever(g, provider, collection)
} }
/* /*
更新 删除 很少用到; 更新 删除 很少用到;
*/ */
......
...@@ -11,169 +11,169 @@ import ( ...@@ -11,169 +11,169 @@ import (
) )
var ( var (
connString = flag.String("dbconn", "", "database connection string") connString = flag.String("dbconn", "", "database connection string")
) )
type QA struct { type QA struct {
ID int64 // 主键 ID int64 // 主键
CreatedAt time.Time // 创建时间 CreatedAt time.Time // 创建时间
FromID *string // 可空的 from_id FromID *string // 可空的 from_id
From *string // 可空的 from From *string // 可空的 from
Question *string // 可空的问题 Question *string // 可空的问题
Answer *string // 可空的答案 Answer *string // 可空的答案
Summary *string // 可空的摘要 Summary *string // 可空的摘要
To *string // 可空的 to To *string // 可空的 to
ToID *string // 可空的 to_id ToID *string // 可空的 to_id
} }
// QAStore 定义 DAO 接口 // QAStore 定义 DAO 接口
type QAStore interface { type QAStore interface {
// GetLatestQA 从 qa_latest_from_id 视图读取指定 from_id 的最新记录 // GetLatestQA 从 qa_latest_from_id 视图读取指定 from_id 的最新记录
GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error)
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error) WriteQA(ctx context.Context, qa QA) (int64, error)
} }
// qaStore 是 QAStore 接口的实现 // qaStore 是 QAStore 接口的实现
type qaStore struct { type qaStore struct {
db *sql.DB db *sql.DB
} }
// NewQAStore 创建新的 QAStore 实例 // NewQAStore 创建新的 QAStore 实例
func NewQAStore(db *sql.DB) QAStore { func NewQAStore(db *sql.DB) QAStore {
return &qaStore{db: db} return &qaStore{db: db}
} }
// 初始化数据库连接并返回 QAStore // 初始化数据库连接并返回 QAStore
func InitQAStore() (QAStore, error) { func InitQAStore() (QAStore, error) {
// Supabase 提供的连接字符串 // Supabase 提供的连接字符串
connString := "postgresql://postgres.awcfgdodiuqnlsobcivq:P99IU9NEoDRPsBfb@aws-0-ap-southeast-1.pooler.supabase.com:5432/postgres" connString := "postgresql://postgres.awcfgdodiuqnlsobcivq:P99IU9NEoDRPsBfb@aws-0-ap-southeast-1.pooler.supabase.com:5432/postgres"
// 打开数据库连接 // 打开数据库连接
db, err := sql.Open("postgres", connString) db, err := sql.Open("postgres", connString)
if err != nil { if err != nil {
return nil, fmt.Errorf("open database: %w", err) return nil, fmt.Errorf("open database: %w", err)
} }
// 测试数据库连接 // 测试数据库连接
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
db.Close() db.Close()
return nil, fmt.Errorf("ping database: %w", err) return nil, fmt.Errorf("ping database: %w", err)
} }
// 返回 QAStore 实例 // 返回 QAStore 实例
return NewQAStore(db), nil return NewQAStore(db), nil
} }
func (s *qaStore) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) { func (s *qaStore) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) {
query := ` query := `
SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id
FROM qa_latest_from_id FROM qa_latest_from_id
WHERE from_id = $1 OR (from_id IS NULL AND $1 IS NULL)` WHERE from_id = $1 OR (from_id IS NULL AND $1 IS NULL)`
args := []interface{}{fromID} args := []interface{}{fromID}
if fromID == nil { if fromID == nil {
args = []interface{}{nil} args = []interface{}{nil}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, fmt.Errorf("query qa_latest_from_id: %w", err) return nil, fmt.Errorf("query qa_latest_from_id: %w", err)
} }
defer rows.Close() defer rows.Close()
var results []QA var results []QA
for rows.Next() { for rows.Next() {
var qa QA var qa QA
var question, answer, summary, from, to, fromIDVal, toIDVal sql.NullString var question, answer, summary, from, to, fromIDVal, toIDVal sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &question, &answer, &summary, &from, &to, &fromIDVal, &toIDVal); err != nil { if err := rows.Scan(&qa.ID, &qa.CreatedAt, &question, &answer, &summary, &from, &to, &fromIDVal, &toIDVal); err != nil {
return nil, fmt.Errorf("scan row: %w", err) return nil, fmt.Errorf("scan row: %w", err)
} }
if question.Valid { if question.Valid {
qa.Question = &question.String qa.Question = &question.String
} }
if answer.Valid { if answer.Valid {
qa.Answer = &answer.String qa.Answer = &answer.String
} }
if summary.Valid { if summary.Valid {
qa.Summary = &summary.String qa.Summary = &summary.String
} }
if from.Valid { if from.Valid {
qa.From = &from.String qa.From = &from.String
} }
if to.Valid { if to.Valid {
qa.To = &to.String qa.To = &to.String
} }
if fromIDVal.Valid { if fromIDVal.Valid {
qa.FromID = &fromIDVal.String qa.FromID = &fromIDVal.String
} }
if toIDVal.Valid { if toIDVal.Valid {
qa.ToID = &toIDVal.String qa.ToID = &toIDVal.String
} }
results = append(results, qa) results = append(results, qa)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err) return nil, fmt.Errorf("row iteration: %w", err)
} }
return results, nil return results, nil
} }
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) { func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 { if qa.ID != 0 {
// 更新记录 // 更新记录
query := ` query := `
UPDATE qa UPDATE qa
SET question = $1, answer = $2, summary = $3, "from" = $4, "to" = $5, from_id = $6, to_id = $7 SET question = $1, answer = $2, summary = $3, "from" = $4, "to" = $5, from_id = $6, to_id = $7
WHERE id = $8 WHERE id = $8
RETURNING id` RETURNING id`
var updatedID int64 var updatedID int64
err := s.db.QueryRowContext(ctx, query, err := s.db.QueryRowContext(ctx, query,
derefString(qa.Question), derefString(qa.Question),
derefString(qa.Answer), derefString(qa.Answer),
derefString(qa.Summary), derefString(qa.Summary),
derefString(qa.From), derefString(qa.From),
derefString(qa.To), derefString(qa.To),
derefString(qa.FromID), derefString(qa.FromID),
derefString(qa.ToID), derefString(qa.ToID),
qa.ID, qa.ID,
).Scan(&updatedID) ).Scan(&updatedID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID) return 0, fmt.Errorf("no record found with id %d", qa.ID)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("update qa: %w", err) return 0, fmt.Errorf("update qa: %w", err)
} }
return updatedID, nil return updatedID, nil
} }
// 插入新记录 // 插入新记录
query := ` query := `
INSERT INTO qa (question, answer, summary, "from", "to", from_id, to_id) INSERT INTO qa (question, answer, summary, "from", "to", from_id, to_id)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id` RETURNING id`
var newID int64 var newID int64
err := s.db.QueryRowContext(ctx, query, err := s.db.QueryRowContext(ctx, query,
derefString(qa.Question), derefString(qa.Question),
derefString(qa.Answer), derefString(qa.Answer),
derefString(qa.Summary), derefString(qa.Summary),
derefString(qa.From), derefString(qa.From),
derefString(qa.To), derefString(qa.To),
derefString(qa.FromID), derefString(qa.FromID),
derefString(qa.ToID), derefString(qa.ToID),
).Scan(&newID) ).Scan(&newID)
if err != nil { if err != nil {
return 0, fmt.Errorf("insert qa: %w", err) return 0, fmt.Errorf("insert qa: %w", err)
} }
return newID, nil return newID, nil
} }
// 辅助函数:处理指针类型的空值 // 辅助函数:处理指针类型的空值
func stringPtr(s string) *string { func stringPtr(s string) *string {
return &s return &s
} }
func derefString(p *string) interface{} { func derefString(p *string) interface{} {
if p == nil { if p == nil {
return nil return nil
} }
return *p return *p
} }
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment