Commit 6ea40366 authored by Wade's avatar Wade

add graph

parent ccab4bec
...@@ -8,5 +8,14 @@ curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/chat ...@@ -8,5 +8,14 @@ curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/chat
curl -d '{"content": "What is the capital of UK?"}' http://localhost:8000/indexDocuments
curl -X POST http://localhost:8000/indexDocuments \
-H "Content-Type: application/json" \
-d '{"content": "What is the capital of UK?", "metadata": {"user_id": "user456", "username": "Bob"}}'
{"result": "Document indexed successfully"}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek" "github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/milvus" "github.com/wade-liwei/agentchat/plugins/milvus"
"github.com/firebase/genkit/go/plugins/evaluators" "github.com/firebase/genkit/go/plugins/evaluators"
...@@ -20,24 +21,24 @@ import ( ...@@ -20,24 +21,24 @@ import (
_ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档 _ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档
) )
// GraphKnowledge
type Input struct { type Input struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"` APIKey string `json:"apiKey,omitempty"`
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
} }
// DocumentInput 结构体用于文档索引接口 // DocumentInput 结构体用于文档索引接口
type DocumentInput struct { type DocumentInput struct {
Content string `json:"content"` Content string `json:"content"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func main() { func main() {
ctx := context.Background() ctx := context.Background()
ds := deepseek.DeepSeek{ ds := deepseek.DeepSeek{
...@@ -59,7 +60,11 @@ func main() { ...@@ -59,7 +60,11 @@ func main() {
}, },
} }
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds,&mil,&googlegenai.GoogleAI{APIKey:"AIzaSyCoYBOmnwRWlH_-nT25lpn8pMg3T18Q0uI"}, &evaluators.GenkitEval{Metrics: metrics})) graph := graphrag.GraphKnowledge{
Addr: "54.92.111.204:5670",
}
g, err := genkit.Init(ctx, genkit.WithPlugins(&ds, &mil, &graph, &googlegenai.GoogleAI{APIKey: "AIzaSyCoYBOmnwRWlH_-nT25lpn8pMg3T18Q0uI"}, &evaluators.GenkitEval{Metrics: metrics}))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
...@@ -71,19 +76,16 @@ func main() { ...@@ -71,19 +76,16 @@ func main() {
}, },
nil) nil)
embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001")
if embedder == nil { if embedder == nil {
log.Fatal("embedder is not defined") log.Fatal("embedder is not defined")
} }
// Configure collection // Configure collection
cfg := milvus.CollectionConfig{ cfg := milvus.CollectionConfig{
Collection: "useridx", Collection: "useridx",
Dimension: 768, // Match mock embedder dimension Dimension: 768, // Match mock embedder dimension
Embedder: embedder, Embedder: embedder,
EmbedderOptions: map[string]interface{}{}, // Explicitly set as map EmbedderOptions: map[string]interface{}{}, // Explicitly set as map
} }
...@@ -91,36 +93,47 @@ func main() { ...@@ -91,36 +93,47 @@ func main() {
indexer, retriever, err := milvus.DefineIndexerAndRetriever(ctx, g, cfg) indexer, retriever, err := milvus.DefineIndexerAndRetriever(ctx, g, cfg)
if err != nil { if err != nil {
log.Fatalf("DefineIndexerAndRetriever failed: %v", err) log.Fatalf("DefineIndexerAndRetriever failed: %v", err)
} }
_ = retriever
_ = retriever
// 定义文档索引流 // 定义文档索引流
genkit.DefineFlow(g, "indexDocuments", func(ctx context.Context, input *DocumentInput) (string, error) { genkit.DefineFlow(g, "indexDocuments", func(ctx context.Context, input *DocumentInput) (string, error) {
doc := ai.DocumentFromText(input.Content, input.Metadata) doc := ai.DocumentFromText(input.Content, input.Metadata)
err := indexer.Index(ctx, &ai.IndexerRequest{ err := indexer.Index(ctx, &ai.IndexerRequest{
Documents:[]*ai.Document{doc}, Documents: []*ai.Document{doc},
}) })
if err != nil { if err != nil {
return "", fmt.Errorf("index document: %w", err) return "", fmt.Errorf("index document: %w", err)
} }
return "Document indexed successfully", nil return "Document indexed successfully", nil
}) })
graphIndexer, graphRetriever, err := graphrag.DefineIndexerAndRetriever(ctx, g)
_ = graphRetriever
genkit.DefineFlow(g, "indexGraph", func(ctx context.Context, input *DocumentInput) (string, error) {
doc := ai.DocumentFromText(input.Content, input.Metadata)
err := graphIndexer.Index(ctx, &ai.IndexerRequest{
Documents: []*ai.Document{doc},
})
if err != nil {
return "", fmt.Errorf("index document: %w", err)
}
return "Document indexed successfully", nil
})
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) {
inputAsJson, err := json.Marshal(input) inputAsJson, err := json.Marshal(input)
if err != nil{ if err != nil {
return "",err return "", err
} }
fmt.Println("input-------------------------------",string(inputAsJson))
fmt.Println("input-------------------------------", string(inputAsJson))
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
...@@ -152,15 +165,13 @@ func main() { ...@@ -152,15 +165,13 @@ func main() {
mux.Handle("POST /"+a.Name(), handler) mux.Handle("POST /"+a.Name(), handler)
} }
// 暴露 Swagger UI,使用 swagger.yaml
mux.HandleFunc("/swagger/", httpSwagger.Handler(
httpSwagger.URL("/docs/swagger.yaml"), // 指定 YAML 文件路径
))
// 暴露 Swagger UI,使用 swagger.yaml // 确保 docs 目录可通过 HTTP 访问
mux.HandleFunc("/swagger/", httpSwagger.Handler( mux.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs"))))
httpSwagger.URL("/docs/swagger.yaml"), // 指定 YAML 文件路径
))
// 确保 docs 目录可通过 HTTP 访问
mux.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs"))))
// 启动服务器,监听 // 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:8000") log.Printf("Server starting on 0.0.0.0:8000")
......
/* // Copyright 2025 Google LLC
//
curl -X 'POST' \ // Licensed under the Apache License, Version 2.0 (the "License");
'http://54.92.111.204:5670/knowledge/space/add' \ // you may not use this file except in compliance with the License.
-H 'accept: application/json' \ // You may obtain a copy of the License at
-H 'Content-Type: application/json' \ //
-d '{ // http://www.apache.org/licenses/LICENSE-2.0
"id": 0, //
"name": "string", // Unless required by applicable law or agreed to in writing, software
"vector_type": "string", // distributed under the License is distributed on an "AS IS" BASIS,
"domain_type": "Normal", // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"desc": "string", // See the License for the specific language governing permissions and
"owner": "string", // limitations under the License.
"space_id": 0 //
}' // SPDX-License-Identifier: Apache-2.0
package graphrag
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/111111/document/add' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"doc_name": "string",
"doc_id": 0,
"doc_type": "string",
"doc_token": "string",
"content": "string",
"source": "string",
"labels": "string",
"questions": [
"string"
]
}'
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/1111/document/sync_batch' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '[
{
"doc_id": 0,
"space_id": "string",
"model_name": "string",
"chunk_parameters": {
"chunk_strategy": "string",
"text_splitter": "string",
"splitter_type": "user_define",
"chunk_size": 512,
"chunk_overlap": 50,
"separator": "\n",
"enable_merge": true
}
}
]'
*/
package knowledge
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv"
"strings"
"sync"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
) )
// Client 知识库客户端 // Client 知识库客户端
...@@ -87,14 +50,15 @@ type SpaceRequest struct { ...@@ -87,14 +50,15 @@ type SpaceRequest struct {
// DocumentRequest 添加文档的请求结构体 // DocumentRequest 添加文档的请求结构体
type DocumentRequest struct { type DocumentRequest struct {
DocName string `json:"doc_name"` DocName string `json:"doc_name"`
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
DocType string `json:"doc_type"` DocType string `json:"doc_type"`
DocToken string `json:"doc_token"` DocToken string `json:"doc_token"`
Content string `json:"content"` Content string `json:"content"`
Source string `json:"source"` Source string `json:"source"`
Labels string `json:"labels"` Labels string `json:"labels"`
Questions []string `json:"questions"` Questions []string `json:"questions"`
Metadata map[string]interface{} `json:"metadata"`
} }
// ChunkParameters 分片参数 // ChunkParameters 分片参数
...@@ -108,7 +72,7 @@ type ChunkParameters struct { ...@@ -108,7 +72,7 @@ type ChunkParameters struct {
EnableMerge bool `json:"enable_merge"` EnableMerge bool `json:"enable_merge"`
} }
// SyncBatchRequest 同步批处理的请求结构体 // SyncBatchRequest 同步批处理的请求结构体
type SyncBatchRequest struct { type SyncBatchRequest struct {
DocID int `json:"doc_id"` DocID int `json:"doc_id"`
SpaceID string `json:"space_id"` SpaceID string `json:"space_id"`
...@@ -173,7 +137,7 @@ func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Respons ...@@ -173,7 +137,7 @@ func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Respons
return resp, nil return resp, nil
} }
// SyncBatchDocument 同步批处理文档 // SyncBatchDocument 同步批处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) { func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID) url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req) body, err := json.Marshal(req)
...@@ -197,3 +161,349 @@ func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*htt ...@@ -197,3 +161,349 @@ func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*htt
return resp, nil return resp, nil
} }
// The provider used in the registry.
const provider = "graphrag"
// Field names for schema.
const (
idField = "id"
textField = "text"
metadataField = "metadata"
)
// GraphKnowledge holds configuration for the plugin.
type GraphKnowledge struct {
Addr string // Knowledge server address (host:port, e.g., "54.92.111.204:5670").
client *Client // Knowledge client.
mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized.
}
// Name returns the plugin name.
func (k *GraphKnowledge) Name() string {
return provider
}
// Init initializes the GraphKnowledge plugin.
func (k *GraphKnowledge) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if k == nil {
k = &GraphKnowledge{}
}
k.mu.Lock()
defer k.mu.Unlock()
defer func() {
if err != nil {
err = fmt.Errorf("graphrag.Init: %w", err)
}
}()
if k.initted {
return errors.New("plugin already initialized")
}
// Load configuration.
addr := k.Addr
if addr == "" {
addr = "54.92.111.204:5670" // Default address.
}
// Initialize Knowledge client.
host, port := parseAddr(addr)
client := NewClient(host, port)
k.client = client
k.initted = true
return nil
}
// parseAddr splits host:port into host and port.
func parseAddr(addr string) (string, int) {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
return "54.92.111.204", 5670
}
port, _ := strconv.Atoi(parts[1])
return parts[0], port
}
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Knowledge space.
func DefineIndexerAndRetriever(ctx context.Context, g *genkit.Genkit) (ai.Indexer, ai.Retriever, error) {
spaceID := ""
modelName := ""
k := genkit.LookupPlugin(g, provider)
if k == nil {
return nil, nil, errors.New("graphrag plugin not found; did you call genkit.Init with the graphrag plugin?")
}
knowledge := k.(*GraphKnowledge)
ds, err := knowledge.newDocStore(ctx, spaceID, modelName)
if err != nil {
return nil, nil, err
}
indexer := genkit.DefineIndexer(g, provider, spaceID, ds.Index)
retriever := genkit.DefineRetriever(g, provider, spaceID, ds.Retrieve)
return indexer, retriever, nil
}
// docStore defines an Indexer and a Retriever.
type docStore struct {
client *Client
spaceID string
modelName string
}
// newDocStore creates a docStore.
func (k *GraphKnowledge) newDocStore(ctx context.Context, spaceID, modelName string) (*docStore, error) {
if k.client == nil {
return nil, errors.New("graphrag.Init not called")
}
return &docStore{
client: k.client,
spaceID: spaceID,
modelName: modelName,
}, nil
}
// Indexer returns the indexer for a space.
func Indexer(g *genkit.Genkit, spaceID string) ai.Indexer {
return genkit.LookupIndexer(g, provider, spaceID)
}
// Retriever returns the retriever for a space.
func Retriever(g *genkit.Genkit, spaceID string) ai.Retriever {
return genkit.LookupRetriever(g, provider, spaceID)
}
// Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 {
return nil
}
// Create knowledge space.
spaceReq := SpaceRequest{
ID: 1,
Name: ds.spaceID,
VectorType: "hnsw",
DomainType: "Normal",
Desc: "Default knowledge space",
Owner: "admin",
SpaceID: 1,
}
resp, err := ds.client.AddSpace(spaceReq)
if err != nil {
return fmt.Errorf("add space: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("add space failed with status %d: %s", resp.StatusCode, string(body))
}
// Index each document.
for i, doc := range req.Documents {
// Ensure metadata includes user_id and username.
if doc.Metadata == nil {
doc.Metadata = make(map[string]interface{})
}
if _, ok := doc.Metadata["user_id"]; !ok {
doc.Metadata["user_id"] = "user123" // Mock data.
}
if _, ok := doc.Metadata["username"]; !ok {
doc.Metadata["username"] = "Alice" // Mock data.
}
// Add document.
var sb strings.Builder
for _, p := range doc.Content {
if p.IsText() {
sb.WriteString(p.Text)
}
}
text := sb.String()
docReq := DocumentRequest{
DocName: fmt.Sprintf("doc_%d", i+1),
DocID: i + 1,
DocType: "text",
DocToken: "",
Content: text,
Source: "api",
Labels: "",
Questions: []string{},
Metadata: doc.Metadata,
}
resp, err := ds.client.AddDocument(ds.spaceID, docReq)
if err != nil {
return fmt.Errorf("add document %d: %w", i+1, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("add document %d failed with status %d: %s", i+1, resp.StatusCode, string(body))
}
// Sync document for embedding.
syncReq := []SyncBatchRequest{
{
DocID: docReq.DocID,
SpaceID: ds.spaceID,
ModelName: ds.modelName,
ChunkParameters: ChunkParameters{
ChunkStrategy: "sentence",
TextSplitter: "recursive",
SplitterType: "user_define",
ChunkSize: 512,
ChunkOverlap: 50,
Separator: "\n",
EnableMerge: true,
},
},
}
syncResp, err := ds.client.SyncBatchDocument(ds.spaceID, syncReq)
if err != nil {
return fmt.Errorf("sync batch document %d: %w", i+1, err)
}
defer syncResp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(syncResp.Body)
return fmt.Errorf("sync batch document %d failed with status %d: %s", i+1, syncResp.StatusCode, string(body))
}
}
return nil
}
// RetrieverOptions for Knowledge retrieval.
type RetrieverOptions struct {
Count int `json:"count,omitempty"` // Max documents to retrieve.
MetricType string `json:"metric_type,omitempty"` // Similarity metric (e.g., "L2", "IP").
}
// Retrieve implements the Retriever.Retrieve method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// count := 3
// metricTypeStr := "L2"
// if req.Options != nil {
// ropt, ok := req.Options.(*RetrieverOptions)
// if !ok {
// return nil, fmt.Errorf("graphrag.Retrieve options have type %T, want %T", req.Options, &RetrieverOptions{})
// }
// if ropt.Count > 0 {
// count = ropt.Count
// }
// if ropt.MetricType != "" {
// metricTypeStr = ropt.MetricType
// }
// }
// Format query for retrieval.
queryText := fmt.Sprintf("Search for: %s", req.Query.Content)
username := "Alice" // Default, override if metadata available.
if req.Query.Metadata != nil {
if uname, ok := req.Query.Metadata["username"].(string); ok {
username = uname
}
}
// Prepare request for chat completions endpoint.
url := fmt.Sprintf("%s/api/v1/chat/completions", ds.client.BaseURL)
chatReq := struct {
ConvUID string `json:"conv_uid"`
UserInput string `json:"user_input"`
UserName string `json:"user_name"`
ChatMode string `json:"chat_mode"`
AppCode string `json:"app_code"`
Temperature float32 `json:"temperature"`
MaxNewTokens int `json:"max_new_tokens"`
SelectParam string `json:"select_param"`
ModelName string `json:"model_name"`
Incremental bool `json:"incremental"`
SysCode string `json:"sys_code"`
PromptCode string `json:"prompt_code"`
ExtInfo map[string]interface{} `json:"ext_info"`
}{
ConvUID: "",
UserInput: queryText,
UserName: username,
ChatMode: "",
AppCode: "",
Temperature: 0.5,
MaxNewTokens: 4000,
SelectParam: "",
ModelName: ds.modelName,
Incremental: false,
SysCode: "",
PromptCode: "",
ExtInfo: map[string]interface{}{
"space_id": ds.spaceID,
//"k": count,
},
}
body, err := json.Marshal(chatReq)
if err != nil {
return nil, fmt.Errorf("marshal chat request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("create chat request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send chat request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("chat completion failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var chatResp struct {
Success bool `json:"success"`
Data struct {
Answer []struct {
Content string `json:"content"`
DocID int `json:"doc_id"`
Score float64 `json:"score"`
Metadata map[string]interface{} `json:"metadata_map"`
} `json:"answer"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
return nil, fmt.Errorf("decode chat response: %w", err)
}
var docs []*ai.Document
for _, doc := range chatResp.Data.Answer {
metadata := doc.Metadata
if metadata == nil {
metadata = make(map[string]interface{})
}
// Ensure metadata includes user_id and username.
if _, ok := metadata["user_id"]; !ok {
metadata["user_id"] = "user123"
}
if _, ok := metadata["username"]; !ok {
metadata["username"] = username
}
aiDoc := ai.DocumentFromText(doc.Content, metadata)
docs = append(docs, aiDoc)
}
return &ai.RetrieverResponse{
Documents: docs,
}, nil
}
/*
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/space/add' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"id": 0,
"name": "string",
"vector_type": "string",
"domain_type": "Normal",
"desc": "string",
"owner": "string",
"space_id": 0
}'
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/111111/document/add' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"doc_name": "string",
"doc_id": 0,
"doc_type": "string",
"doc_token": "string",
"content": "string",
"source": "string",
"labels": "string",
"questions": [
"string"
]
}'
curl -X 'POST' \
'http://54.92.111.204:5670/knowledge/1111/document/sync_batch' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '[
{
"doc_id": 0,
"space_id": "string",
"model_name": "string",
"chunk_parameters": {
"chunk_strategy": "string",
"text_splitter": "string",
"splitter_type": "user_define",
"chunk_size": 512,
"chunk_overlap": 50,
"separator": "\n",
"enable_merge": true
}
}
]'
curl -X 'POST' \
'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"conv_uid": "",
"user_input": "",
"user_name": "string",
"chat_mode": "",
"app_code": "",
"temperature": 0.5,
"max_new_tokens": 4000,
"select_param": "string",
"model_name": "string",
"incremental": false,
"sys_code": "string",
"prompt_code": "string",
"ext_info": {}
}'
*/
package knowledge
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
)
// Client 知识库客户端
type Client struct {
BaseURL string // 基础URL,例如 "http://54.92.111.204:5670"
}
// SpaceRequest 创建空间的请求结构体
type SpaceRequest struct {
ID int `json:"id"`
Name string `json:"name"`
VectorType string `json:"vector_type"`
DomainType string `json:"domain_type"`
Desc string `json:"desc"`
Owner string `json:"owner"`
SpaceID int `json:"space_id"`
}
// DocumentRequest 添加文档的请求结构体
type DocumentRequest struct {
DocName string `json:"doc_name"`
DocID int `json:"doc_id"`
DocType string `json:"doc_type"`
DocToken string `json:"doc_token"`
Content string `json:"content"`
Source string `json:"source"`
Labels string `json:"labels"`
Questions []string `json:"questions"`
}
// ChunkParameters 分片参数
type ChunkParameters struct {
ChunkStrategy string `json:"chunk_strategy"`
TextSplitter string `json:"text_splitter"`
SplitterType string `json:"splitter_type"`
ChunkSize int `json:"chunk_size"`
ChunkOverlap int `json:"chunk_overlap"`
Separator string `json:"separator"`
EnableMerge bool `json:"enable_merge"`
}
// SyncBatchRequest 同步批处理的请求结构体
type SyncBatchRequest struct {
DocID int `json:"doc_id"`
SpaceID string `json:"space_id"`
ModelName string `json:"model_name"`
ChunkParameters ChunkParameters `json:"chunk_parameters"`
}
// NewClient 创建新的客户端实例
func NewClient(ip string, port int) *Client {
return &Client{
BaseURL: fmt.Sprintf("http://%s:%d", ip, port),
}
}
// AddSpace 创建知识空间
func (c *Client) AddSpace(req SpaceRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/space/add", c.BaseURL)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// AddDocument 添加文档
func (c *Client) AddDocument(spaceID string, req DocumentRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/add", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
// SyncBatchDocument 同步批处理文档
func (c *Client) SyncBatchDocument(spaceID string, req []SyncBatchRequest) (*http.Response, error) {
url := fmt.Sprintf("%s/knowledge/%s/document/sync_batch", c.BaseURL, spaceID)
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
return resp, nil
}
SILICONFLOW_API_KEY=sk-ogigzbogipwhtkwvwnoeiovjdalkotopnpkwkxlvsvjsmyms docker compose up -d
curl 'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'Accept-Language: zh-CN,zh;q=0.9' \
-H 'Connection: keep-alive' \
-H 'Content-Type: application/json' \
-b 'sb-eowcghempsnzqalhkois-auth-token=base64-eyJhY2Nlc3NfdG9rZW4iOiJleUpoYkdjaU9pSklVekkxTmlJc0ltdHBaQ0k2SWpkemVpdEtXQzlCU0ZaSFFUWnBhVE1pTENKMGVYQWlPaUpLVjFRaWZRLmV5SnBjM01pT2lKb2RIUndjem92TDJWdmQyTm5hR1Z0Y0hOdWVuRmhiR2hyYjJsekxuTjFjR0ZpWVhObExtTnZMMkYxZEdndmRqRWlMQ0p6ZFdJaU9pSXpaamRoWVdRMU1TMHpNVGd3TFRSbVlXRXRPV0ZsTnkwMlpXRm1PV001Tm1JNFpHVWlMQ0poZFdRaU9pSmhkWFJvWlc1MGFXTmhkR1ZrSWl3aVpYaHdJam94TnpRNE16TXpOREV5TENKcFlYUWlPakUzTkRnek1qazRNVElzSW1WdFlXbHNJam9pYkdsM1pXbGZkMkZrWlVCcFkyeHZkV1F1WTI5dElpd2ljR2h2Ym1VaU9pSWlMQ0poY0hCZmJXVjBZV1JoZEdFaU9uc2ljSEp2ZG1sa1pYSWlPaUpsYldGcGJDSXNJbkJ5YjNacFpHVnljeUk2V3lKbGJXRnBiQ0pkZlN3aWRYTmxjbDl0WlhSaFpHRjBZU0k2ZXlKbGJXRnBiQ0k2SW14cGQyVnBYM2RoWkdWQWFXTnNiM1ZrTG1OdmJTSXNJbVZ0WVdsc1gzWmxjbWxtYVdWa0lqcDBjblZsTENKd2FHOXVaVjkyWlhKcFptbGxaQ0k2Wm1Gc2MyVXNJbk4xWWlJNklqTm1OMkZoWkRVeExUTXhPREF0TkdaaFlTMDVZV1UzTFRabFlXWTVZemsyWWpoa1pTSjlMQ0p5YjJ4bElqb2lZWFYwYUdWdWRHbGpZWFJsWkNJc0ltRmhiQ0k2SW1GaGJERWlMQ0poYlhJaU9sdDdJbTFsZEdodlpDSTZJbkJoYzNOM2IzSmtJaXdpZEdsdFpYTjBZVzF3SWpveE56UTJOVE0zTXpNNWZWMHNJbk5sYzNOcGIyNWZhV1FpT2lKa1lUTmhOelpqTmkweVlXVXdMVFEyWVdNdFlUaGxOUzFtWVRCallUTXpObUl5TlRZaUxDSnBjMTloYm05dWVXMXZkWE1pT21aaGJITmxmUS52MEhKdFRHN3EzSnc4QkRqdlFrWm9pOTVKcnNVZGNUZi1FWjBvc2d6OEk0IiwidG9rZW5fdHlwZSI6ImJlYXJlciIsImV4cGlyZXNfaW4iOjM2MDAsImV4cGlyZXNfYXQiOjE3NDgzMzM0MTIsInJlZnJlc2hfdG9rZW4iOiJscmpzdGl5MnM0a2giLCJ1c2VyIjp7ImlkIjoiM2Y3YWFkNTEtMzE4MC00ZmFhLTlhZTctNmVhZjljOTZiOGRlIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsInJvbGUiOiJhdXRoZW50aWNhdGVkIiwiZW1haWwiOiJsaXdlaV93YWRlQGljbG91ZC5jb20iLCJlbWFpbF9jb25maXJtZWRfYXQiOiIyMDI1LTAyLTEyVDExOjA0OjQwLjIxNzkzOVoiLCJwaG9uZSI6IiIsImNvbmZpcm1hdGlvbl9zZW50X2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC41Mzk2MjhaIiwiY29uZmlybWVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDo0MC4yMTc5MzlaIiwibGFzdF9zaWduX2luX2F0IjoiMjAyNS0wNS0wNlQxMzo0NToyMC45MDA0MTRaIiwiYXBwX21ldGFkYXRhIjp7InByb3ZpZGVyIjoiZW1haWwiLCJwcm92aWRlcnMiOlsiZW1haWwiXX0sInVzZXJfbWV0YWRhdGEiOnsiZW1haWwiOiJsaXdlaV93YWRlQGljbG91ZC5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwicGhvbmVfdmVyaWZpZWQiOmZhbHNlLCJzdWIiOiIzZjdhYWQ1MS0zMTgwLTRmYWEtOWFlNy02ZWFmOWM5NmI4ZGUifSwiaWRlbnRpdGllcyI6W3siaWRlbnRpdHlfaWQiOiI2MjFiYTUxZi0yYzYzLTQxOWMtOWI2OS0zYzUzYTc5NDlhMzkiLCJpZCI6IjNmN2FhZDUxLTMxODAtNGZhYS05YWU3LTZlYWY5Yzk2YjhkZSIsInVzZXJfaWQiOiIzZjdhYWQ1MS0zMTgwLTRmYWEtOWFlNy02ZWFmOWM5NmI4ZGUiLCJpZGVudGl0eV9kYXRhIjp7ImVtYWlsIjoibGl3ZWlfd2FkZUBpY2xvdWQuY29tIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsInBob25lX3ZlcmlmaWVkIjpmYWxzZSwic3ViIjoiM2Y3YWFkNTEtMzE4MC00ZmFhLTlhZTctNmVhZjljOTZiOGRlIn0sInByb3ZpZGVyIjoiZW1haWwiLCJsYXN0X3NpZ25faW5fYXQiOiIyMDI1LTAyLTEyVDExOjA0OjAwLjUxMTAwOVoiLCJjcmVhdGVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC41MTExNDRaIiwidXBkYXRlZF9hdCI6IjIwMjUtMDItMTJUMTE6MDQ6MDAuNTExMTQ0WiIsImVtYWlsIjoibGl3ZWlfd2FkZUBpY2xvdWQuY29tIn1dLCJjcmVhdGVkX2F0IjoiMjAyNS0wMi0xMlQxMTowNDowMC40NDYwMzdaIiwidXBkYXRlZF9hdCI6IjIwMjUtMDUtMjdUMDc6MTA6MTIuMjM0MDUyWiIsImlzX2Fub255bW91cyI6ZmFsc2V9fQ' \
-H 'Origin: http://54.92.111.204:5670' \
-H 'Referer: http://54.92.111.204:5670/chat?scene=chat_knowledge&id=a6997f46-3f9b-11f0-b9d7-36eb2f648a81&knowledge_id=bbbbbb' \
-H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36' \
-H 'accept: text/event-stream' \
-H 'user-id;' \
--data-raw '{"chat_mode":"chat_knowledge","model_name":"Qwen/Qwen2.5-Coder-32B-Instruct","user_input":"你好","app_code":"chat_knowledge","temperature":0.6,"max_new_tokens":4000,"select_param":"bbbbbb","conv_uid":"a6997f46-3f9b-11f0-b9d7-36eb2f648a81"}'
curl -X 'POST' \
'http://54.92.111.204:5670/api/v1/chat/completions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"conv_uid": "",
"user_input": "",
"user_name": "string",
"chat_mode": "",
"app_code": "",
"temperature": 0.5,
"max_new_tokens": 4000,
"select_param": "string",
"model_name": "string",
"incremental": false,
"sys_code": "string",
"prompt_code": "string",
"ext_info": {}
}'
...@@ -273,6 +273,7 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer { ...@@ -273,6 +273,7 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
func Retriever(g *genkit.Genkit, collection string) ai.Retriever { func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection) return genkit.LookupRetriever(g, provider, collection)
} }
/* /*
更新 删除 很少用到; 更新 删除 很少用到;
*/ */
......
...@@ -11,169 +11,169 @@ import ( ...@@ -11,169 +11,169 @@ import (
) )
var ( var (
connString = flag.String("dbconn", "", "database connection string") connString = flag.String("dbconn", "", "database connection string")
) )
type QA struct { type QA struct {
ID int64 // 主键 ID int64 // 主键
CreatedAt time.Time // 创建时间 CreatedAt time.Time // 创建时间
FromID *string // 可空的 from_id FromID *string // 可空的 from_id
From *string // 可空的 from From *string // 可空的 from
Question *string // 可空的问题 Question *string // 可空的问题
Answer *string // 可空的答案 Answer *string // 可空的答案
Summary *string // 可空的摘要 Summary *string // 可空的摘要
To *string // 可空的 to To *string // 可空的 to
ToID *string // 可空的 to_id ToID *string // 可空的 to_id
} }
// QAStore 定义 DAO 接口 // QAStore 定义 DAO 接口
type QAStore interface { type QAStore interface {
// GetLatestQA 从 qa_latest_from_id 视图读取指定 from_id 的最新记录 // GetLatestQA 从 qa_latest_from_id 视图读取指定 from_id 的最新记录
GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error)
// WriteQA 插入或更新 qa 表记录 // WriteQA 插入或更新 qa 表记录
WriteQA(ctx context.Context, qa QA) (int64, error) WriteQA(ctx context.Context, qa QA) (int64, error)
} }
// qaStore 是 QAStore 接口的实现 // qaStore 是 QAStore 接口的实现
type qaStore struct { type qaStore struct {
db *sql.DB db *sql.DB
} }
// NewQAStore 创建新的 QAStore 实例 // NewQAStore 创建新的 QAStore 实例
func NewQAStore(db *sql.DB) QAStore { func NewQAStore(db *sql.DB) QAStore {
return &qaStore{db: db} return &qaStore{db: db}
} }
// 初始化数据库连接并返回 QAStore // 初始化数据库连接并返回 QAStore
func InitQAStore() (QAStore, error) { func InitQAStore() (QAStore, error) {
// Supabase 提供的连接字符串 // Supabase 提供的连接字符串
connString := "postgresql://postgres.awcfgdodiuqnlsobcivq:P99IU9NEoDRPsBfb@aws-0-ap-southeast-1.pooler.supabase.com:5432/postgres" connString := "postgresql://postgres.awcfgdodiuqnlsobcivq:P99IU9NEoDRPsBfb@aws-0-ap-southeast-1.pooler.supabase.com:5432/postgres"
// 打开数据库连接 // 打开数据库连接
db, err := sql.Open("postgres", connString) db, err := sql.Open("postgres", connString)
if err != nil { if err != nil {
return nil, fmt.Errorf("open database: %w", err) return nil, fmt.Errorf("open database: %w", err)
} }
// 测试数据库连接 // 测试数据库连接
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
db.Close() db.Close()
return nil, fmt.Errorf("ping database: %w", err) return nil, fmt.Errorf("ping database: %w", err)
} }
// 返回 QAStore 实例 // 返回 QAStore 实例
return NewQAStore(db), nil return NewQAStore(db), nil
} }
func (s *qaStore) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) { func (s *qaStore) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) {
query := ` query := `
SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id
FROM qa_latest_from_id FROM qa_latest_from_id
WHERE from_id = $1 OR (from_id IS NULL AND $1 IS NULL)` WHERE from_id = $1 OR (from_id IS NULL AND $1 IS NULL)`
args := []interface{}{fromID} args := []interface{}{fromID}
if fromID == nil { if fromID == nil {
args = []interface{}{nil} args = []interface{}{nil}
} }
rows, err := s.db.QueryContext(ctx, query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, fmt.Errorf("query qa_latest_from_id: %w", err) return nil, fmt.Errorf("query qa_latest_from_id: %w", err)
} }
defer rows.Close() defer rows.Close()
var results []QA var results []QA
for rows.Next() { for rows.Next() {
var qa QA var qa QA
var question, answer, summary, from, to, fromIDVal, toIDVal sql.NullString var question, answer, summary, from, to, fromIDVal, toIDVal sql.NullString
if err := rows.Scan(&qa.ID, &qa.CreatedAt, &question, &answer, &summary, &from, &to, &fromIDVal, &toIDVal); err != nil { if err := rows.Scan(&qa.ID, &qa.CreatedAt, &question, &answer, &summary, &from, &to, &fromIDVal, &toIDVal); err != nil {
return nil, fmt.Errorf("scan row: %w", err) return nil, fmt.Errorf("scan row: %w", err)
} }
if question.Valid { if question.Valid {
qa.Question = &question.String qa.Question = &question.String
} }
if answer.Valid { if answer.Valid {
qa.Answer = &answer.String qa.Answer = &answer.String
} }
if summary.Valid { if summary.Valid {
qa.Summary = &summary.String qa.Summary = &summary.String
} }
if from.Valid { if from.Valid {
qa.From = &from.String qa.From = &from.String
} }
if to.Valid { if to.Valid {
qa.To = &to.String qa.To = &to.String
} }
if fromIDVal.Valid { if fromIDVal.Valid {
qa.FromID = &fromIDVal.String qa.FromID = &fromIDVal.String
} }
if toIDVal.Valid { if toIDVal.Valid {
qa.ToID = &toIDVal.String qa.ToID = &toIDVal.String
} }
results = append(results, qa) results = append(results, qa)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration: %w", err) return nil, fmt.Errorf("row iteration: %w", err)
} }
return results, nil return results, nil
} }
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) { func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
if qa.ID != 0 { if qa.ID != 0 {
// 更新记录 // 更新记录
query := ` query := `
UPDATE qa UPDATE qa
SET question = $1, answer = $2, summary = $3, "from" = $4, "to" = $5, from_id = $6, to_id = $7 SET question = $1, answer = $2, summary = $3, "from" = $4, "to" = $5, from_id = $6, to_id = $7
WHERE id = $8 WHERE id = $8
RETURNING id` RETURNING id`
var updatedID int64 var updatedID int64
err := s.db.QueryRowContext(ctx, query, err := s.db.QueryRowContext(ctx, query,
derefString(qa.Question), derefString(qa.Question),
derefString(qa.Answer), derefString(qa.Answer),
derefString(qa.Summary), derefString(qa.Summary),
derefString(qa.From), derefString(qa.From),
derefString(qa.To), derefString(qa.To),
derefString(qa.FromID), derefString(qa.FromID),
derefString(qa.ToID), derefString(qa.ToID),
qa.ID, qa.ID,
).Scan(&updatedID) ).Scan(&updatedID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, fmt.Errorf("no record found with id %d", qa.ID) return 0, fmt.Errorf("no record found with id %d", qa.ID)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("update qa: %w", err) return 0, fmt.Errorf("update qa: %w", err)
} }
return updatedID, nil return updatedID, nil
} }
// 插入新记录 // 插入新记录
query := ` query := `
INSERT INTO qa (question, answer, summary, "from", "to", from_id, to_id) INSERT INTO qa (question, answer, summary, "from", "to", from_id, to_id)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id` RETURNING id`
var newID int64 var newID int64
err := s.db.QueryRowContext(ctx, query, err := s.db.QueryRowContext(ctx, query,
derefString(qa.Question), derefString(qa.Question),
derefString(qa.Answer), derefString(qa.Answer),
derefString(qa.Summary), derefString(qa.Summary),
derefString(qa.From), derefString(qa.From),
derefString(qa.To), derefString(qa.To),
derefString(qa.FromID), derefString(qa.FromID),
derefString(qa.ToID), derefString(qa.ToID),
).Scan(&newID) ).Scan(&newID)
if err != nil { if err != nil {
return 0, fmt.Errorf("insert qa: %w", err) return 0, fmt.Errorf("insert qa: %w", err)
} }
return newID, nil return newID, nil
} }
// 辅助函数:处理指针类型的空值 // 辅助函数:处理指针类型的空值
func stringPtr(s string) *string { func stringPtr(s string) *string {
return &s return &s
} }
func derefString(p *string) interface{} { func derefString(p *string) interface{} {
if p == nil { if p == nil {
return nil return nil
} }
return *p return *p
} }
\ No newline at end of file
...@@ -13,180 +13,180 @@ import ( ...@@ -13,180 +13,180 @@ import (
// TestQAStore 测试 QAStore 的读写功能 // TestQAStore 测试 QAStore 的读写功能
func TestQAStore(t *testing.T) { func TestQAStore(t *testing.T) {
// 初始化测试数据库连接 // 初始化测试数据库连接
store, err := InitQAStore() store, err := InitQAStore()
require.NoError(t, err, "failed to initialize QAStore") require.NoError(t, err, "failed to initialize QAStore")
defer store.(*qaStore).db.Close() defer store.(*qaStore).db.Close()
ctx := context.Background() ctx := context.Background()
// 清理测试数据(可选,确保测试环境干净) // 清理测试数据(可选,确保测试环境干净)
// cleanup := func() { // cleanup := func() {
// _, err := store.(*qaStore).db.ExecContext(ctx, "TRUNCATE TABLE qa RESTART IDENTITY") // _, err := store.(*qaStore).db.ExecContext(ctx, "TRUNCATE TABLE qa RESTART IDENTITY")
// require.NoError(t, err, "failed to truncate qa table") // require.NoError(t, err, "failed to truncate qa table")
// } // }
// cleanup() // cleanup()
t.Run("WriteQA_Insert", func(t *testing.T) { t.Run("WriteQA_Insert", func(t *testing.T) {
// 测试插入新记录 // 测试插入新记录
qa := QA{ qa := QA{
Question: stringPtr("What is Go?"), Question: stringPtr("What is Go?"),
Answer: stringPtr("A programming language"), Answer: stringPtr("A programming language"),
Summary: stringPtr("Go introduction"), Summary: stringPtr("Go introduction"),
From: stringPtr("Alice"), From: stringPtr("Alice"),
FromID: stringPtr("user123"), FromID: stringPtr("user123"),
To: stringPtr("Bob"), To: stringPtr("Bob"),
ToID: stringPtr("user456"), ToID: stringPtr("user456"),
} }
id, err := store.WriteQA(ctx, qa) id, err := store.WriteQA(ctx, qa)
require.NoError(t, err, "failed to insert QA") require.NoError(t, err, "failed to insert QA")
assert.NotZero(t, id, "inserted ID should not be zero") assert.NotZero(t, id, "inserted ID should not be zero")
// 验证数据库中的记录 // 验证数据库中的记录
var stored QA var stored QA
row := store.(*qaStore).db.QueryRowContext(ctx, ` row := store.(*qaStore).db.QueryRowContext(ctx, `
SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id
FROM qa FROM qa
WHERE id = $1`, id) WHERE id = $1`, id)
var question, answer, summary, from, to, fromID, toID sql.NullString var question, answer, summary, from, to, fromID, toID sql.NullString
err = row.Scan(&stored.ID, &stored.CreatedAt, &question, &answer, &summary, &from, &to, &fromID, &toID) err = row.Scan(&stored.ID, &stored.CreatedAt, &question, &answer, &summary, &from, &to, &fromID, &toID)
require.NoError(t, err, "failed to query inserted QA") require.NoError(t, err, "failed to query inserted QA")
// 设置结构体字段 // 设置结构体字段
if question.Valid { if question.Valid {
stored.Question = &question.String stored.Question = &question.String
} }
if answer.Valid { if answer.Valid {
stored.Answer = &answer.String stored.Answer = &answer.String
} }
if summary.Valid { if summary.Valid {
stored.Summary = &summary.String stored.Summary = &summary.String
} }
if from.Valid { if from.Valid {
stored.From = &from.String stored.From = &from.String
} }
if to.Valid { if to.Valid {
stored.To = &to.String stored.To = &to.String
} }
if fromID.Valid { if fromID.Valid {
stored.FromID = &fromID.String stored.FromID = &fromID.String
} }
if toID.Valid { if toID.Valid {
stored.ToID = &toID.String stored.ToID = &toID.String
} }
// 验证插入的数据 // 验证插入的数据
assert.Equal(t, qa.Question, stored.Question) assert.Equal(t, qa.Question, stored.Question)
assert.Equal(t, qa.Answer, stored.Answer) assert.Equal(t, qa.Answer, stored.Answer)
assert.Equal(t, qa.Summary, stored.Summary) assert.Equal(t, qa.Summary, stored.Summary)
assert.Equal(t, qa.From, stored.From) assert.Equal(t, qa.From, stored.From)
assert.Equal(t, qa.To, stored.To) assert.Equal(t, qa.To, stored.To)
assert.Equal(t, qa.FromID, stored.FromID) assert.Equal(t, qa.FromID, stored.FromID)
assert.Equal(t, qa.ToID, stored.ToID) assert.Equal(t, qa.ToID, stored.ToID)
}) })
t.Run("WriteQA_Update", func(t *testing.T) { t.Run("WriteQA_Update", func(t *testing.T) {
// 先插入一条记录 // 先插入一条记录
qa := QA{ qa := QA{
Question: stringPtr("What is Python?"), Question: stringPtr("What is Python?"),
Answer: stringPtr("Another language"), Answer: stringPtr("Another language"),
FromID: stringPtr("user789"), FromID: stringPtr("user789"),
} }
id, err := store.WriteQA(ctx, qa) id, err := store.WriteQA(ctx, qa)
require.NoError(t, err, "failed to insert QA for update test") require.NoError(t, err, "failed to insert QA for update test")
// 更新记录 // 更新记录
updatedQA := QA{ updatedQA := QA{
ID: id, ID: id,
Question: stringPtr("Updated: What is Python?"), Question: stringPtr("Updated: What is Python?"),
Answer: stringPtr("A versatile language"), Answer: stringPtr("A versatile language"),
Summary: stringPtr("Python summary"), Summary: stringPtr("Python summary"),
FromID: stringPtr("user789"), FromID: stringPtr("user789"),
} }
updatedID, err := store.WriteQA(ctx, updatedQA) updatedID, err := store.WriteQA(ctx, updatedQA)
require.NoError(t, err, "failed to update QA") require.NoError(t, err, "failed to update QA")
assert.Equal(t, id, updatedID, "updated ID should match inserted ID") assert.Equal(t, id, updatedID, "updated ID should match inserted ID")
// 验证更新后的记录 // 验证更新后的记录
var stored QA var stored QA
row := store.(*qaStore).db.QueryRowContext(ctx, ` row := store.(*qaStore).db.QueryRowContext(ctx, `
SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id
FROM qa FROM qa
WHERE id = $1`, id) WHERE id = $1`, id)
var question, answer, summary, from, to, fromID, toID sql.NullString var question, answer, summary, from, to, fromID, toID sql.NullString
err = row.Scan(&stored.ID, &stored.CreatedAt, &question, &answer, &summary, &from, &to, &fromID, &toID) err = row.Scan(&stored.ID, &stored.CreatedAt, &question, &answer, &summary, &from, &to, &fromID, &toID)
require.NoError(t, err, "failed to query updated QA") require.NoError(t, err, "failed to query updated QA")
if question.Valid { if question.Valid {
stored.Question = &question.String stored.Question = &question.String
} }
if answer.Valid { if answer.Valid {
stored.Answer = &answer.String stored.Answer = &answer.String
} }
if summary.Valid { if summary.Valid {
stored.Summary = &summary.String stored.Summary = &summary.String
} }
if fromID.Valid { if fromID.Valid {
stored.FromID = &fromID.String stored.FromID = &fromID.String
} }
assert.Equal(t, updatedQA.Question, stored.Question) assert.Equal(t, updatedQA.Question, stored.Question)
assert.Equal(t, updatedQA.Answer, stored.Answer) assert.Equal(t, updatedQA.Answer, stored.Answer)
assert.Equal(t, updatedQA.Summary, stored.Summary) assert.Equal(t, updatedQA.Summary, stored.Summary)
assert.Equal(t, updatedQA.FromID, stored.FromID) assert.Equal(t, updatedQA.FromID, stored.FromID)
}) })
t.Run("GetLatestQA", func(t *testing.T) { t.Run("GetLatestQA", func(t *testing.T) {
// 插入多条记录以测试视图 // 插入多条记录以测试视图
qa1 := QA{ qa1 := QA{
Question: stringPtr("First question"), Question: stringPtr("First question"),
Answer: stringPtr("First answer"), Answer: stringPtr("First answer"),
FromID: stringPtr("user123"), FromID: stringPtr("user123"),
CreatedAt: time.Now().Add(-2 * time.Hour), CreatedAt: time.Now().Add(-2 * time.Hour),
} }
_, err := store.WriteQA(ctx, qa1) _, err := store.WriteQA(ctx, qa1)
require.NoError(t, err, "failed to insert first QA") require.NoError(t, err, "failed to insert first QA")
qa2 := QA{ qa2 := QA{
Question: stringPtr("Second question"), Question: stringPtr("Second question"),
Answer: stringPtr("Second answer"), Answer: stringPtr("Second answer"),
FromID: stringPtr("user123"), FromID: stringPtr("user123"),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
_, err = store.WriteQA(ctx, qa2) _, err = store.WriteQA(ctx, qa2)
require.NoError(t, err, "failed to insert second QA") require.NoError(t, err, "failed to insert second QA")
// 查询最新记录 // 查询最新记录
fromID := stringPtr("user123") fromID := stringPtr("user123")
results, err := store.GetLatestQA(ctx, fromID) results, err := store.GetLatestQA(ctx, fromID)
require.NoError(t, err, "failed to get latest QA") require.NoError(t, err, "failed to get latest QA")
require.Len(t, results, 1, "should return exactly one record for from_id") require.Len(t, results, 1, "should return exactly one record for from_id")
// 验证返回的记录是最新的一条 // 验证返回的记录是最新的一条
assert.Equal(t, qa2.Question, results[0].Question) assert.Equal(t, qa2.Question, results[0].Question)
assert.Equal(t, qa2.Answer, results[0].Answer) assert.Equal(t, qa2.Answer, results[0].Answer)
assert.Equal(t, qa2.FromID, results[0].FromID) assert.Equal(t, qa2.FromID, results[0].FromID)
}) })
t.Run("GetLatestQA_NullFromID", func(t *testing.T) { t.Run("GetLatestQA_NullFromID", func(t *testing.T) {
// 插入一条 from_id 为 NULL 的记录 // 插入一条 from_id 为 NULL 的记录
qa := QA{ qa := QA{
Question: stringPtr("Null from_id question"), Question: stringPtr("Null from_id question"),
Answer: stringPtr("Null from_id answer"), Answer: stringPtr("Null from_id answer"),
FromID: nil, FromID: nil,
} }
_, err := store.WriteQA(ctx, qa) _, err := store.WriteQA(ctx, qa)
require.NoError(t, err, "failed to insert QA with null from_id") require.NoError(t, err, "failed to insert QA with null from_id")
// 查询 from_id 为 NULL 的记录 // 查询 from_id 为 NULL 的记录
results, err := store.GetLatestQA(ctx, nil) results, err := store.GetLatestQA(ctx, nil)
require.NoError(t, err, "failed to get latest QA for null from_id") require.NoError(t, err, "failed to get latest QA for null from_id")
assert.NotEmpty(t, results, "should return records for null from_id") assert.NotEmpty(t, results, "should return records for null from_id")
assert.Nil(t, results[0].FromID, "from_id should be nil") assert.Nil(t, results[0].FromID, "from_id should be nil")
assert.Equal(t, qa.Question, results[0].Question) assert.Equal(t, qa.Question, results[0].Question)
}) })
// 清理测试数据 // 清理测试数据
// cleanup() // cleanup()
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment