Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
A
agentchat
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
李伟@五瓣科技
agentchat
Commits
c3342c58
Commit
c3342c58
authored
May 30, 2025
by
Wade
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add rate limit
parent
b28247b8
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
902 additions
and
839 deletions
+902
-839
go.mod
go.mod
+1
-0
go.sum
go.sum
+2
-0
main.go
main.go
+32
-20
deepseek.go
plugins/deepseek/deepseek.go
+100
-108
graph.go
plugins/graphrag/graph.go
+93
-93
milvus.go
plugins/milvus/milvus.go
+339
-341
milvus_test.go
plugins/milvus/milvus_test.go
+141
-142
qa.go
qa.go
+135
-135
rate.go
rate.go
+59
-0
No files found.
go.mod
View file @
c3342c58
...
@@ -80,6 +80,7 @@ require (
...
@@ -80,6 +80,7 @@ require (
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genai v1.5.0 // indirect
google.golang.org/genai v1.5.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
google.golang.org/grpc v1.72.0 // indirect
google.golang.org/grpc v1.72.0 // indirect
...
...
go.sum
View file @
c3342c58
...
@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
...
@@ -416,6 +416,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
...
...
main.go
View file @
c3342c58
...
@@ -4,11 +4,13 @@ import (
...
@@ -4,11 +4,13 @@ import (
"context"
"context"
"fmt"
"fmt"
"log"
"log"
"net/http"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/genkit"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/firebase/genkit/go/plugins/server"
)
)
func
main
()
{
func
main
()
{
...
@@ -16,7 +18,7 @@ func main() {
...
@@ -16,7 +18,7 @@ func main() {
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
ds
:=
deepseek
.
DeepSeek
{
ds
:=
deepseek
.
DeepSeek
{
APIKey
:
"sk-9f70df871a7c4b8aa566a3c7a0603706"
,
APIKey
:
"sk-9f70df871a7c4b8aa566a3c7a0603706"
,
}
}
g
,
err
:=
genkit
.
Init
(
ctx
,
genkit
.
WithPlugins
(
&
ds
))
g
,
err
:=
genkit
.
Init
(
ctx
,
genkit
.
WithPlugins
(
&
ds
))
...
@@ -24,38 +26,48 @@ func main() {
...
@@ -24,38 +26,48 @@ func main() {
log
.
Fatal
(
err
)
log
.
Fatal
(
err
)
}
}
m
:=
ds
.
DefineModel
(
g
,
m
:=
ds
.
DefineModel
(
g
,
deepseek
.
ModelDefinition
{
deepseek
.
ModelDefinition
{
Name
:
"deepseek-chat"
,
// Choose an appropriate model
Name
:
"deepseek-chat"
,
// Choose an appropriate model
Type
:
"chat"
,
// Must be chat for tool support
Type
:
"chat"
,
// Must be chat for tool support
},
},
nil
)
nil
)
// Define a simple flow that generates jokes about a given topic
// Define a simple flow that generates jokes about a given topic
//
genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) {
genkit
.
DefineFlow
(
g
,
"jokesFlow"
,
func
(
ctx
context
.
Context
,
input
string
)
(
string
,
error
)
{
resp
,
err
:=
genkit
.
Generate
(
ctx
,
g
,
resp
,
err
:=
genkit
.
Generate
(
ctx
,
g
,
ai
.
WithModel
(
m
),
ai
.
WithModel
(
m
),
ai
.
WithPrompt
(
`Tell silly short jokes about apple`
))
ai
.
WithPrompt
(
`Tell silly short jokes about apple`
))
if
err
!=
nil
{
if
err
!=
nil
{
fmt
.
Println
(
err
.
Error
())
fmt
.
Println
(
err
.
Error
())
return
return
""
,
err
}
}
fmt
.
Println
(
"resp.Text()"
,
resp
.
Text
())
// if err != nil {
fmt
.
Println
(
"resp.Text()"
,
resp
.
Text
())
// return "", err
// }
// text := resp.Text()
if
err
!=
nil
{
// return text, nil
return
""
,
err
// })
}
//<-ctx.Done()
text
:=
resp
.
Text
()
}
return
text
,
nil
})
// 配置限速器:每秒 10 次请求,突发容量 20,最大并发 5
rl
:=
NewRateLimiter
(
10
,
20
,
5
)
// 创建 Genkit HTTP 处理器
mux
:=
http
.
NewServeMux
()
for
_
,
a
:=
range
genkit
.
ListFlows
(
g
)
{
handler
:=
rl
.
Middleware
(
genkit
.
Handler
(
a
))
mux
.
Handle
(
"POST /"
+
a
.
Name
(),
handler
)
}
// 启动服务器,监听
log
.
Printf
(
"Server starting on 0.0.0.0:3400"
)
if
err
:=
server
.
Start
(
ctx
,
"0.0.0.0:3400"
,
mux
);
err
!=
nil
{
log
.
Fatalf
(
"Server failed: %v"
,
err
)
}
}
plugins/deepseek/deepseek.go
View file @
c3342c58
...
@@ -16,7 +16,7 @@ import (
...
@@ -16,7 +16,7 @@ import (
const
provider
=
"deepseek"
const
provider
=
"deepseek"
var
(
var
(
mediaSupportedModels
=
[]
string
{
deepseek
.
DeepSeekChat
,
deepseek
.
DeepSeekCoder
,
deepseek
.
DeepSeekReasoner
}
mediaSupportedModels
=
[]
string
{
deepseek
.
DeepSeekChat
,
deepseek
.
DeepSeekCoder
,
deepseek
.
DeepSeekReasoner
}
// toolSupportedModels = []string{
// toolSupportedModels = []string{
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
// "qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
...
@@ -34,154 +34,148 @@ var (
...
@@ -34,154 +34,148 @@ var (
}
}
)
)
// DeepSeek holds configuration for the plugin.
// DeepSeek holds configuration for the plugin.
type
DeepSeek
struct
{
type
DeepSeek
struct
{
APIKey
string
// DeepSeek API key
APIKey
string
// DeepSeek API key
//ServerAddress string
//ServerAddress string
mu
sync
.
Mutex
// Mutex to control access.
mu
sync
.
Mutex
// Mutex to control access.
initted
bool
// Whether the plugin has been initialized.
initted
bool
// Whether the plugin has been initialized.
}
}
// Name returns the provider name.
// Name returns the provider name.
func
(
d
DeepSeek
)
Name
()
string
{
func
(
d
DeepSeek
)
Name
()
string
{
return
provider
return
provider
}
}
// ModelDefinition represents a model with its name and type.
// ModelDefinition represents a model with its name and type.
type
ModelDefinition
struct
{
type
ModelDefinition
struct
{
Name
string
Name
string
Type
string
Type
string
}
}
// // DefineModel defines a DeepSeek model in Genkit.
// // DefineModel defines a DeepSeek model in Genkit.
func
(
d
*
DeepSeek
)
DefineModel
(
g
*
genkit
.
Genkit
,
model
ModelDefinition
,
info
*
ai
.
ModelInfo
)
ai
.
Model
{
func
(
d
*
DeepSeek
)
DefineModel
(
g
*
genkit
.
Genkit
,
model
ModelDefinition
,
info
*
ai
.
ModelInfo
)
ai
.
Model
{
d
.
mu
.
Lock
()
d
.
mu
.
Lock
()
defer
d
.
mu
.
Unlock
()
defer
d
.
mu
.
Unlock
()
if
!
d
.
initted
{
if
!
d
.
initted
{
panic
(
"deepseek.Init not called"
)
panic
(
"deepseek.Init not called"
)
}
}
// Define model info, supporting multiturn and system role.
// Define model info, supporting multiturn and system role.
mi
:=
ai
.
ModelInfo
{
mi
:=
ai
.
ModelInfo
{
Label
:
model
.
Name
,
Label
:
model
.
Name
,
Supports
:
&
ai
.
ModelSupports
{
Supports
:
&
ai
.
ModelSupports
{
Multiturn
:
true
,
Multiturn
:
true
,
SystemRole
:
true
,
SystemRole
:
true
,
Media
:
false
,
// DeepSeek API primarily supports text.
Media
:
false
,
// DeepSeek API primarily supports text.
Tools
:
false
,
// Tools not yet supported in this implementation.
Tools
:
false
,
// Tools not yet supported in this implementation.
},
},
Versions
:
[]
string
{},
Versions
:
[]
string
{},
}
}
if
info
!=
nil
{
if
info
!=
nil
{
mi
=
*
info
mi
=
*
info
}
}
meta
:=
&
ai
.
ModelInfo
{
meta
:=
&
ai
.
ModelInfo
{
// Label: "DeepSeek - " + model.Name,
// Label: "DeepSeek - " + model.Name,
Label
:
model
.
Name
,
Label
:
model
.
Name
,
Supports
:
mi
.
Supports
,
Supports
:
mi
.
Supports
,
Versions
:
[]
string
{},
Versions
:
[]
string
{},
}
}
gen
:=
&
generator
{
model
:
model
,
apiKey
:
d
.
APIKey
}
gen
:=
&
generator
{
model
:
model
,
apiKey
:
d
.
APIKey
}
return
genkit
.
DefineModel
(
g
,
provider
,
model
.
Name
,
meta
,
gen
.
generate
)
return
genkit
.
DefineModel
(
g
,
provider
,
model
.
Name
,
meta
,
gen
.
generate
)
}
}
// Init initializes the DeepSeek plugin.
// Init initializes the DeepSeek plugin.
func
(
d
*
DeepSeek
)
Init
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
)
error
{
func
(
d
*
DeepSeek
)
Init
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
)
error
{
d
.
mu
.
Lock
()
d
.
mu
.
Lock
()
defer
d
.
mu
.
Unlock
()
defer
d
.
mu
.
Unlock
()
if
d
.
initted
{
if
d
.
initted
{
panic
(
"deepseek.Init already called"
)
panic
(
"deepseek.Init already called"
)
}
}
if
d
==
nil
||
d
.
APIKey
==
""
{
if
d
==
nil
||
d
.
APIKey
==
""
{
return
fmt
.
Errorf
(
"deepseek: need APIKey"
)
return
fmt
.
Errorf
(
"deepseek: need APIKey"
)
}
}
d
.
initted
=
true
d
.
initted
=
true
return
nil
return
nil
}
}
// generator handles model generation.
// generator handles model generation.
type
generator
struct
{
type
generator
struct
{
model
ModelDefinition
model
ModelDefinition
apiKey
string
apiKey
string
}
}
// generate implements the Genkit model generation interface.
// generate implements the Genkit model generation interface.
func
(
g
*
generator
)
generate
(
ctx
context
.
Context
,
input
*
ai
.
ModelRequest
,
cb
func
(
context
.
Context
,
*
ai
.
ModelResponseChunk
)
error
)
(
*
ai
.
ModelResponse
,
error
)
{
func
(
g
*
generator
)
generate
(
ctx
context
.
Context
,
input
*
ai
.
ModelRequest
,
cb
func
(
context
.
Context
,
*
ai
.
ModelResponseChunk
)
error
)
(
*
ai
.
ModelResponse
,
error
)
{
// stream := cb != nil
// stream := cb != nil
if
len
(
input
.
Messages
)
==
0
{
if
len
(
input
.
Messages
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"prompt or messages required"
)
return
nil
,
fmt
.
Errorf
(
"prompt or messages required"
)
}
}
// Set up the Deepseek client
// Set up the Deepseek client
// Initialize DeepSeek client.
// Initialize DeepSeek client.
client
:=
deepseek
.
NewClient
(
g
.
apiKey
)
client
:=
deepseek
.
NewClient
(
g
.
apiKey
)
// Create a chat completion request
// Create a chat completion request
request
:=
&
deepseek
.
ChatCompletionRequest
{
request
:=
&
deepseek
.
ChatCompletionRequest
{
Model
:
g
.
model
.
Name
,
Model
:
g
.
model
.
Name
,
}
}
for
_
,
msg
:=
range
input
.
Messages
{
for
_
,
msg
:=
range
input
.
Messages
{
role
,
ok
:=
roleMapping
[
msg
.
Role
]
role
,
ok
:=
roleMapping
[
msg
.
Role
]
if
!
ok
{
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"unsupported role: %s"
,
msg
.
Role
)
return
nil
,
fmt
.
Errorf
(
"unsupported role: %s"
,
msg
.
Role
)
}
content
:=
concatMessageParts
(
msg
.
Content
)
request
.
Messages
=
append
(
request
.
Messages
,
deepseek
.
ChatCompletionMessage
{
Role
:
role
,
Content
:
content
,
})
}
}
content
:=
concatMessageParts
(
msg
.
Content
)
request
.
Messages
=
append
(
request
.
Messages
,
deepseek
.
ChatCompletionMessage
{
Role
:
role
,
Content
:
content
,
})
}
// Send the request and handle the response
// Send the request and handle the response
response
,
err
:=
client
.
CreateChatCompletion
(
ctx
,
request
)
response
,
err
:=
client
.
CreateChatCompletion
(
ctx
,
request
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Fatalf
(
"error: %v"
,
err
)
log
.
Fatalf
(
"error: %v"
,
err
)
}
}
// Print the response
fmt
.
Println
(
"Response:"
,
response
.
Choices
[
0
]
.
Message
.
Content
)
// Create a final response with the merged chunks
finalResponse
:=
&
ai
.
ModelResponse
{
Request
:
input
,
FinishReason
:
ai
.
FinishReason
(
"stop"
),
Message
:
&
ai
.
Message
{
Role
:
ai
.
RoleModel
,
},
}
for
_
,
chunk
:=
range
response
.
Choices
{
// Print the response
p
:=
ai
.
Part
{
fmt
.
Println
(
"Response:"
,
response
.
Choices
[
0
]
.
Message
.
Content
)
Text
:
chunk
.
Message
.
Content
,
Kind
:
ai
.
PartKind
(
chunk
.
Index
),
// Create a final response with the merged chunks
}
finalResponse
:=
&
ai
.
ModelResponse
{
Request
:
input
,
FinishReason
:
ai
.
FinishReason
(
"stop"
),
Message
:
&
ai
.
Message
{
Role
:
ai
.
RoleModel
,
},
}
finalResponse
.
Message
.
Content
=
append
(
finalResponse
.
Message
.
Content
,
&
p
)
for
_
,
chunk
:=
range
response
.
Choices
{
p
:=
ai
.
Part
{
Text
:
chunk
.
Message
.
Content
,
Kind
:
ai
.
PartKind
(
chunk
.
Index
),
}
}
return
finalResponse
,
nil
// Return the final merged response
finalResponse
.
Message
.
Content
=
append
(
finalResponse
.
Message
.
Content
,
&
p
)
}
return
finalResponse
,
nil
// Return the final merged response
}
}
// concatMessageParts concatenates message parts into a single string.
// concatMessageParts concatenates message parts into a single string.
func
concatMessageParts
(
parts
[]
*
ai
.
Part
)
string
{
func
concatMessageParts
(
parts
[]
*
ai
.
Part
)
string
{
var
sb
strings
.
Builder
var
sb
strings
.
Builder
for
_
,
part
:=
range
parts
{
for
_
,
part
:=
range
parts
{
if
part
.
IsText
()
{
if
part
.
IsText
()
{
sb
.
WriteString
(
part
.
Text
)
sb
.
WriteString
(
part
.
Text
)
}
}
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
// Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
}
}
return
sb
.
String
()
return
sb
.
String
()
}
}
/*
/*
// Choice represents a completion choice generated by the model.
// Choice represents a completion choice generated by the model.
...
@@ -205,5 +199,3 @@ type Part struct {
...
@@ -205,5 +199,3 @@ type Part struct {
}
}
*/
*/
plugins/graphrag/graph.go
View file @
c3342c58
...
@@ -71,129 +71,129 @@ import (
...
@@ -71,129 +71,129 @@ import (
// Client 知识库客户端
// Client 知识库客户端
type
Client
struct
{
type
Client
struct
{
BaseURL
string
// 基础URL,例如 "http://54.92.111.204:5670"
BaseURL
string
// 基础URL,例如 "http://54.92.111.204:5670"
}
}
// SpaceRequest 创建空间的请求结构体
// SpaceRequest 创建空间的请求结构体
type
SpaceRequest
struct
{
type
SpaceRequest
struct
{
ID
int
`json:"id"`
ID
int
`json:"id"`
Name
string
`json:"name"`
Name
string
`json:"name"`
VectorType
string
`json:"vector_type"`
VectorType
string
`json:"vector_type"`
DomainType
string
`json:"domain_type"`
DomainType
string
`json:"domain_type"`
Desc
string
`json:"desc"`
Desc
string
`json:"desc"`
Owner
string
`json:"owner"`
Owner
string
`json:"owner"`
SpaceID
int
`json:"space_id"`
SpaceID
int
`json:"space_id"`
}
}
// DocumentRequest 添加文档的请求结构体
// DocumentRequest 添加文档的请求结构体
type
DocumentRequest
struct
{
type
DocumentRequest
struct
{
DocName
string
`json:"doc_name"`
DocName
string
`json:"doc_name"`
DocID
int
`json:"doc_id"`
DocID
int
`json:"doc_id"`
DocType
string
`json:"doc_type"`
DocType
string
`json:"doc_type"`
DocToken
string
`json:"doc_token"`
DocToken
string
`json:"doc_token"`
Content
string
`json:"content"`
Content
string
`json:"content"`
Source
string
`json:"source"`
Source
string
`json:"source"`
Labels
string
`json:"labels"`
Labels
string
`json:"labels"`
Questions
[]
string
`json:"questions"`
Questions
[]
string
`json:"questions"`
}
}
// ChunkParameters 分片参数
// ChunkParameters 分片参数
type
ChunkParameters
struct
{
type
ChunkParameters
struct
{
ChunkStrategy
string
`json:"chunk_strategy"`
ChunkStrategy
string
`json:"chunk_strategy"`
TextSplitter
string
`json:"text_splitter"`
TextSplitter
string
`json:"text_splitter"`
SplitterType
string
`json:"splitter_type"`
SplitterType
string
`json:"splitter_type"`
ChunkSize
int
`json:"chunk_size"`
ChunkSize
int
`json:"chunk_size"`
ChunkOverlap
int
`json:"chunk_overlap"`
ChunkOverlap
int
`json:"chunk_overlap"`
Separator
string
`json:"separator"`
Separator
string
`json:"separator"`
EnableMerge
bool
`json:"enable_merge"`
EnableMerge
bool
`json:"enable_merge"`
}
}
// SyncBatchRequest 同步批处理的请求结构体
// SyncBatchRequest 同步批处理的请求结构体
type
SyncBatchRequest
struct
{
type
SyncBatchRequest
struct
{
DocID
int
`json:"doc_id"`
DocID
int
`json:"doc_id"`
SpaceID
string
`json:"space_id"`
SpaceID
string
`json:"space_id"`
ModelName
string
`json:"model_name"`
ModelName
string
`json:"model_name"`
ChunkParameters
ChunkParameters
`json:"chunk_parameters"`
ChunkParameters
ChunkParameters
`json:"chunk_parameters"`
}
}
// NewClient 创建新的客户端实例
// NewClient 创建新的客户端实例
func
NewClient
(
ip
string
,
port
int
)
*
Client
{
func
NewClient
(
ip
string
,
port
int
)
*
Client
{
return
&
Client
{
return
&
Client
{
BaseURL
:
fmt
.
Sprintf
(
"http://%s:%d"
,
ip
,
port
),
BaseURL
:
fmt
.
Sprintf
(
"http://%s:%d"
,
ip
,
port
),
}
}
}
}
// AddSpace 创建知识空间
// AddSpace 创建知识空间
func
(
c
*
Client
)
AddSpace
(
req
SpaceRequest
)
(
*
http
.
Response
,
error
)
{
func
(
c
*
Client
)
AddSpace
(
req
SpaceRequest
)
(
*
http
.
Response
,
error
)
{
url
:=
fmt
.
Sprintf
(
"%s/knowledge/space/add"
,
c
.
BaseURL
)
url
:=
fmt
.
Sprintf
(
"%s/knowledge/space/add"
,
c
.
BaseURL
)
body
,
err
:=
json
.
Marshal
(
req
)
body
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
}
}
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
}
}
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
client
:=
&
http
.
Client
{}
client
:=
&
http
.
Client
{}
resp
,
err
:=
client
.
Do
(
httpReq
)
resp
,
err
:=
client
.
Do
(
httpReq
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
}
}
return
resp
,
nil
return
resp
,
nil
}
}
// AddDocument 添加文档
// AddDocument 添加文档
func
(
c
*
Client
)
AddDocument
(
spaceID
string
,
req
DocumentRequest
)
(
*
http
.
Response
,
error
)
{
func
(
c
*
Client
)
AddDocument
(
spaceID
string
,
req
DocumentRequest
)
(
*
http
.
Response
,
error
)
{
url
:=
fmt
.
Sprintf
(
"%s/knowledge/%s/document/add"
,
c
.
BaseURL
,
spaceID
)
url
:=
fmt
.
Sprintf
(
"%s/knowledge/%s/document/add"
,
c
.
BaseURL
,
spaceID
)
body
,
err
:=
json
.
Marshal
(
req
)
body
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
}
}
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
}
}
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
client
:=
&
http
.
Client
{}
client
:=
&
http
.
Client
{}
resp
,
err
:=
client
.
Do
(
httpReq
)
resp
,
err
:=
client
.
Do
(
httpReq
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
}
}
return
resp
,
nil
return
resp
,
nil
}
}
// SyncBatchDocument 同步批处理文档
// SyncBatchDocument 同步批处理文档
func
(
c
*
Client
)
SyncBatchDocument
(
spaceID
string
,
req
[]
SyncBatchRequest
)
(
*
http
.
Response
,
error
)
{
func
(
c
*
Client
)
SyncBatchDocument
(
spaceID
string
,
req
[]
SyncBatchRequest
)
(
*
http
.
Response
,
error
)
{
url
:=
fmt
.
Sprintf
(
"%s/knowledge/%s/document/sync_batch"
,
c
.
BaseURL
,
spaceID
)
url
:=
fmt
.
Sprintf
(
"%s/knowledge/%s/document/sync_batch"
,
c
.
BaseURL
,
spaceID
)
body
,
err
:=
json
.
Marshal
(
req
)
body
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to marshal request: %w"
,
err
)
}
}
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
httpReq
,
err
:=
http
.
NewRequest
(
"POST"
,
url
,
bytes
.
NewBuffer
(
body
))
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
}
}
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
client
:=
&
http
.
Client
{}
client
:=
&
http
.
Client
{}
resp
,
err
:=
client
.
Do
(
httpReq
)
resp
,
err
:=
client
.
Do
(
httpReq
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to send request: %w"
,
err
)
}
}
return
resp
,
nil
return
resp
,
nil
}
}
plugins/milvus/milvus.go
View file @
c3342c58
...
@@ -36,393 +36,391 @@ const provider = "milvus"
...
@@ -36,393 +36,391 @@ const provider = "milvus"
// Field names for Milvus schema.
// Field names for Milvus schema.
const
(
const
(
idField
=
"id"
idField
=
"id"
vectorField
=
"vector"
vectorField
=
"vector"
textField
=
"text"
textField
=
"text"
metadataField
=
"metadata"
metadataField
=
"metadata"
)
)
// Milvus holds configuration for the plugin.
// Milvus holds configuration for the plugin.
type
Milvus
struct
{
type
Milvus
struct
{
// Milvus server address (host:port, e.g., "localhost:19530").
// Milvus server address (host:port, e.g., "localhost:19530").
// Defaults to MILVUS_ADDRESS environment variable.
// Defaults to MILVUS_ADDRESS environment variable.
Addr
string
Addr
string
// Username for authentication.
// Username for authentication.
// Defaults to MILVUS_USERNAME.
// Defaults to MILVUS_USERNAME.
Username
string
Username
string
// Password for authentication.
// Password for authentication.
// Defaults to MILVUS_PASSWORD.
// Defaults to MILVUS_PASSWORD.
Password
string
Password
string
// Token for authentication (alternative to username/password).
// Token for authentication (alternative to username/password).
// Defaults to MILVUS_TOKEN.
// Defaults to MILVUS_TOKEN.
Token
string
Token
string
client
client
.
Client
// Milvus client.
client
client
.
Client
// Milvus client.
mu
sync
.
Mutex
// Mutex to control access.
mu
sync
.
Mutex
// Mutex to control access.
initted
bool
// Whether the plugin has been initialized.
initted
bool
// Whether the plugin has been initialized.
}
}
// Name returns the plugin name.
// Name returns the plugin name.
func
(
m
*
Milvus
)
Name
()
string
{
func
(
m
*
Milvus
)
Name
()
string
{
return
provider
return
provider
}
}
// Init initializes the Milvus plugin.
// Init initializes the Milvus plugin.
func
(
m
*
Milvus
)
Init
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
)
(
err
error
)
{
func
(
m
*
Milvus
)
Init
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
)
(
err
error
)
{
if
m
==
nil
{
if
m
==
nil
{
m
=
&
Milvus
{}
m
=
&
Milvus
{}
}
}
m
.
mu
.
Lock
()
m
.
mu
.
Lock
()
defer
m
.
mu
.
Unlock
()
defer
m
.
mu
.
Unlock
()
defer
func
()
{
defer
func
()
{
if
err
!=
nil
{
if
err
!=
nil
{
err
=
fmt
.
Errorf
(
"milvus.Init: %w"
,
err
)
err
=
fmt
.
Errorf
(
"milvus.Init: %w"
,
err
)
}
}
}()
}()
if
m
.
initted
{
if
m
.
initted
{
return
errors
.
New
(
"plugin already initialized"
)
return
errors
.
New
(
"plugin already initialized"
)
}
}
// Load configuration.
// Load configuration.
addr
:=
m
.
Addr
addr
:=
m
.
Addr
if
addr
==
""
{
if
addr
==
""
{
addr
=
os
.
Getenv
(
"MILVUS_ADDRESS"
)
addr
=
os
.
Getenv
(
"MILVUS_ADDRESS"
)
}
}
if
addr
==
""
{
if
addr
==
""
{
return
errors
.
New
(
"milvus address required"
)
return
errors
.
New
(
"milvus address required"
)
}
}
username
:=
m
.
Username
username
:=
m
.
Username
if
username
==
""
{
if
username
==
""
{
username
=
os
.
Getenv
(
"MILVUS_USERNAME"
)
username
=
os
.
Getenv
(
"MILVUS_USERNAME"
)
}
}
password
:=
m
.
Password
password
:=
m
.
Password
if
password
==
""
{
if
password
==
""
{
password
=
os
.
Getenv
(
"MILVUS_PASSWORD"
)
password
=
os
.
Getenv
(
"MILVUS_PASSWORD"
)
}
}
token
:=
m
.
Token
token
:=
m
.
Token
if
token
==
""
{
if
token
==
""
{
token
=
os
.
Getenv
(
"MILVUS_TOKEN"
)
token
=
os
.
Getenv
(
"MILVUS_TOKEN"
)
}
}
// Initialize Milvus client.
// Initialize Milvus client.
config
:=
client
.
Config
{
config
:=
client
.
Config
{
Address
:
addr
,
Address
:
addr
,
Username
:
username
,
Username
:
username
,
Password
:
password
,
Password
:
password
,
APIKey
:
token
,
APIKey
:
token
,
}
}
client
,
err
:=
client
.
NewClient
(
ctx
,
config
)
client
,
err
:=
client
.
NewClient
(
ctx
,
config
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to initialize Milvus client: %v"
,
err
)
return
fmt
.
Errorf
(
"failed to initialize Milvus client: %v"
,
err
)
}
}
m
.
client
=
client
m
.
client
=
client
m
.
initted
=
true
m
.
initted
=
true
return
nil
return
nil
}
}
// CollectionConfig holds configuration for an indexer/retriever pair.
// CollectionConfig holds configuration for an indexer/retriever pair.
type
CollectionConfig
struct
{
type
CollectionConfig
struct
{
// Milvus collection name. Must not be empty.
// Milvus collection name. Must not be empty.
Collection
string
Collection
string
// Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
// Embedding vector dimension (e.g., 1536 for text-embedding-ada-002).
Dimension
int
Dimension
int
// Embedder for generating vectors.
// Embedder for generating vectors.
Embedder
ai
.
Embedder
Embedder
ai
.
Embedder
// Embedder options.
// Embedder options.
EmbedderOptions
any
EmbedderOptions
any
}
}
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
// DefineIndexerAndRetriever defines an Indexer and Retriever for a Milvus collection.
func
DefineIndexerAndRetriever
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
,
cfg
CollectionConfig
)
(
ai
.
Indexer
,
ai
.
Retriever
,
error
)
{
func
DefineIndexerAndRetriever
(
ctx
context
.
Context
,
g
*
genkit
.
Genkit
,
cfg
CollectionConfig
)
(
ai
.
Indexer
,
ai
.
Retriever
,
error
)
{
if
cfg
.
Embedder
==
nil
{
if
cfg
.
Embedder
==
nil
{
return
nil
,
nil
,
errors
.
New
(
"milvus: Embedder required"
)
return
nil
,
nil
,
errors
.
New
(
"milvus: Embedder required"
)
}
}
if
cfg
.
Collection
==
""
{
if
cfg
.
Collection
==
""
{
return
nil
,
nil
,
errors
.
New
(
"milvus: collection name required"
)
return
nil
,
nil
,
errors
.
New
(
"milvus: collection name required"
)
}
}
if
cfg
.
Dimension
<=
0
{
if
cfg
.
Dimension
<=
0
{
return
nil
,
nil
,
errors
.
New
(
"milvus: dimension must be positive"
)
return
nil
,
nil
,
errors
.
New
(
"milvus: dimension must be positive"
)
}
}
m
:=
genkit
.
LookupPlugin
(
g
,
provider
)
m
:=
genkit
.
LookupPlugin
(
g
,
provider
)
if
m
==
nil
{
if
m
==
nil
{
return
nil
,
nil
,
errors
.
New
(
"milvus plugin not found; did you call genkit.Init with the milvus plugin?"
)
return
nil
,
nil
,
errors
.
New
(
"milvus plugin not found; did you call genkit.Init with the milvus plugin?"
)
}
}
milvus
:=
m
.
(
*
Milvus
)
milvus
:=
m
.
(
*
Milvus
)
ds
,
err
:=
milvus
.
newDocStore
(
ctx
,
&
cfg
)
ds
,
err
:=
milvus
.
newDocStore
(
ctx
,
&
cfg
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
indexer
:=
genkit
.
DefineIndexer
(
g
,
provider
,
cfg
.
Collection
,
ds
.
Index
)
indexer
:=
genkit
.
DefineIndexer
(
g
,
provider
,
cfg
.
Collection
,
ds
.
Index
)
retriever
:=
genkit
.
DefineRetriever
(
g
,
provider
,
cfg
.
Collection
,
ds
.
Retrieve
)
retriever
:=
genkit
.
DefineRetriever
(
g
,
provider
,
cfg
.
Collection
,
ds
.
Retrieve
)
return
indexer
,
retriever
,
nil
return
indexer
,
retriever
,
nil
}
}
// docStore defines an Indexer and a Retriever.
// docStore defines an Indexer and a Retriever.
type
docStore
struct
{
type
docStore
struct
{
client
client
.
Client
client
client
.
Client
collection
string
collection
string
dimension
int
dimension
int
embedder
ai
.
Embedder
embedder
ai
.
Embedder
embedderOptions
map
[
string
]
interface
{}
embedderOptions
map
[
string
]
interface
{}
}
}
// newDocStore creates a docStore.
// newDocStore creates a docStore.
func
(
m
*
Milvus
)
newDocStore
(
ctx
context
.
Context
,
cfg
*
CollectionConfig
)
(
*
docStore
,
error
)
{
func
(
m
*
Milvus
)
newDocStore
(
ctx
context
.
Context
,
cfg
*
CollectionConfig
)
(
*
docStore
,
error
)
{
if
m
.
client
==
nil
{
if
m
.
client
==
nil
{
return
nil
,
errors
.
New
(
"milvus.Init not called"
)
return
nil
,
errors
.
New
(
"milvus.Init not called"
)
}
}
// Check/create collection.
// Check/create collection.
exists
,
err
:=
m
.
client
.
HasCollection
(
ctx
,
cfg
.
Collection
)
exists
,
err
:=
m
.
client
.
HasCollection
(
ctx
,
cfg
.
Collection
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to check collection %q: %v"
,
cfg
.
Collection
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to check collection %q: %v"
,
cfg
.
Collection
,
err
)
}
}
if
!
exists
{
if
!
exists
{
// Define schema.
// Define schema.
schema
:=
&
entity
.
Schema
{
schema
:=
&
entity
.
Schema
{
CollectionName
:
cfg
.
Collection
,
CollectionName
:
cfg
.
Collection
,
Fields
:
[]
*
entity
.
Field
{
Fields
:
[]
*
entity
.
Field
{
{
{
Name
:
idField
,
Name
:
idField
,
DataType
:
entity
.
FieldTypeInt64
,
DataType
:
entity
.
FieldTypeInt64
,
PrimaryKey
:
true
,
PrimaryKey
:
true
,
AutoID
:
true
,
AutoID
:
true
,
},
},
{
{
Name
:
vectorField
,
Name
:
vectorField
,
DataType
:
entity
.
FieldTypeFloatVector
,
DataType
:
entity
.
FieldTypeFloatVector
,
TypeParams
:
map
[
string
]
string
{
TypeParams
:
map
[
string
]
string
{
"dim"
:
fmt
.
Sprintf
(
"%d"
,
cfg
.
Dimension
),
"dim"
:
fmt
.
Sprintf
(
"%d"
,
cfg
.
Dimension
),
},
},
},
},
{
{
Name
:
textField
,
Name
:
textField
,
DataType
:
entity
.
FieldTypeVarChar
,
DataType
:
entity
.
FieldTypeVarChar
,
TypeParams
:
map
[
string
]
string
{
TypeParams
:
map
[
string
]
string
{
"max_length"
:
"65535"
,
"max_length"
:
"65535"
,
},
},
},
},
{
{
Name
:
metadataField
,
Name
:
metadataField
,
DataType
:
entity
.
FieldTypeJSON
,
DataType
:
entity
.
FieldTypeJSON
,
},
},
},
},
}
}
err
=
m
.
client
.
CreateCollection
(
ctx
,
schema
,
entity
.
DefaultShardNumber
)
err
=
m
.
client
.
CreateCollection
(
ctx
,
schema
,
entity
.
DefaultShardNumber
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create collection %q: %v"
,
cfg
.
Collection
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to create collection %q: %v"
,
cfg
.
Collection
,
err
)
}
}
// Create HNSW index.
// Create HNSW index.
index
,
err
:=
entity
.
NewIndexHNSW
(
index
,
err
:=
entity
.
NewIndexHNSW
(
entity
.
L2
,
entity
.
L2
,
8
,
// M
8
,
// M
96
,
// efConstruction
96
,
// efConstruction
)
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"entity.NewIndexHNSW: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"entity.NewIndexHNSW: %v"
,
err
)
}
}
err
=
m
.
client
.
CreateIndex
(
ctx
,
cfg
.
Collection
,
vectorField
,
index
,
false
)
err
=
m
.
client
.
CreateIndex
(
ctx
,
cfg
.
Collection
,
vectorField
,
index
,
false
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create index: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to create index: %v"
,
err
)
}
}
}
}
// Load collection.
// Load collection.
err
=
m
.
client
.
LoadCollection
(
ctx
,
cfg
.
Collection
,
false
)
err
=
m
.
client
.
LoadCollection
(
ctx
,
cfg
.
Collection
,
false
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to load collection %q: %v"
,
cfg
.
Collection
,
err
)
return
nil
,
fmt
.
Errorf
(
"failed to load collection %q: %v"
,
cfg
.
Collection
,
err
)
}
}
// Convert EmbedderOptions to map[string]interface{}.
// Convert EmbedderOptions to map[string]interface{}.
var
embedderOptions
map
[
string
]
interface
{}
var
embedderOptions
map
[
string
]
interface
{}
if
cfg
.
EmbedderOptions
!=
nil
{
if
cfg
.
EmbedderOptions
!=
nil
{
opts
,
ok
:=
cfg
.
EmbedderOptions
.
(
map
[
string
]
interface
{})
opts
,
ok
:=
cfg
.
EmbedderOptions
.
(
map
[
string
]
interface
{})
if
!
ok
{
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"EmbedderOptions must be a map[string]interface{}, got %T"
,
cfg
.
EmbedderOptions
)
return
nil
,
fmt
.
Errorf
(
"EmbedderOptions must be a map[string]interface{}, got %T"
,
cfg
.
EmbedderOptions
)
}
}
embedderOptions
=
opts
embedderOptions
=
opts
}
else
{
}
else
{
embedderOptions
=
make
(
map
[
string
]
interface
{})
embedderOptions
=
make
(
map
[
string
]
interface
{})
}
}
return
&
docStore
{
return
&
docStore
{
client
:
m
.
client
,
client
:
m
.
client
,
collection
:
cfg
.
Collection
,
collection
:
cfg
.
Collection
,
dimension
:
cfg
.
Dimension
,
dimension
:
cfg
.
Dimension
,
embedder
:
cfg
.
Embedder
,
embedder
:
cfg
.
Embedder
,
embedderOptions
:
embedderOptions
,
embedderOptions
:
embedderOptions
,
},
nil
},
nil
}
}
// Indexer returns the indexer for a collection.
// Indexer returns the indexer for a collection.
func
Indexer
(
g
*
genkit
.
Genkit
,
collection
string
)
ai
.
Indexer
{
func
Indexer
(
g
*
genkit
.
Genkit
,
collection
string
)
ai
.
Indexer
{
return
genkit
.
LookupIndexer
(
g
,
provider
,
collection
)
return
genkit
.
LookupIndexer
(
g
,
provider
,
collection
)
}
}
// Retriever returns the retriever for a collection.
// Retriever returns the retriever for a collection.
func
Retriever
(
g
*
genkit
.
Genkit
,
collection
string
)
ai
.
Retriever
{
func
Retriever
(
g
*
genkit
.
Genkit
,
collection
string
)
ai
.
Retriever
{
return
genkit
.
LookupRetriever
(
g
,
provider
,
collection
)
return
genkit
.
LookupRetriever
(
g
,
provider
,
collection
)
}
}
// Index implements the Indexer.Index method.
// Index implements the Indexer.Index method.
func
(
ds
*
docStore
)
Index
(
ctx
context
.
Context
,
req
*
ai
.
IndexerRequest
)
error
{
func
(
ds
*
docStore
)
Index
(
ctx
context
.
Context
,
req
*
ai
.
IndexerRequest
)
error
{
if
len
(
req
.
Documents
)
==
0
{
if
len
(
req
.
Documents
)
==
0
{
return
nil
return
nil
}
}
// Embed documents.
// Embed documents.
ereq
:=
&
ai
.
EmbedRequest
{
ereq
:=
&
ai
.
EmbedRequest
{
Input
:
req
.
Documents
,
Input
:
req
.
Documents
,
Options
:
ds
.
embedderOptions
,
Options
:
ds
.
embedderOptions
,
}
}
eres
,
err
:=
ds
.
embedder
.
Embed
(
ctx
,
ereq
)
eres
,
err
:=
ds
.
embedder
.
Embed
(
ctx
,
ereq
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"milvus index embedding failed: %w"
,
err
)
return
fmt
.
Errorf
(
"milvus index embedding failed: %w"
,
err
)
}
}
// Validate embedding count matches document count.
// Validate embedding count matches document count.
if
len
(
eres
.
Embeddings
)
!=
len
(
req
.
Documents
)
{
if
len
(
eres
.
Embeddings
)
!=
len
(
req
.
Documents
)
{
return
fmt
.
Errorf
(
"mismatch: got %d embeddings for %d documents"
,
len
(
eres
.
Embeddings
),
len
(
req
.
Documents
))
return
fmt
.
Errorf
(
"mismatch: got %d embeddings for %d documents"
,
len
(
eres
.
Embeddings
),
len
(
req
.
Documents
))
}
}
// Prepare row-based data.
// Prepare row-based data.
var
rows
[]
interface
{}
var
rows
[]
interface
{}
for
i
,
emb
:=
range
eres
.
Embeddings
{
for
i
,
emb
:=
range
eres
.
Embeddings
{
doc
:=
req
.
Documents
[
i
]
doc
:=
req
.
Documents
[
i
]
var
sb
strings
.
Builder
var
sb
strings
.
Builder
for
_
,
p
:=
range
doc
.
Content
{
for
_
,
p
:=
range
doc
.
Content
{
if
p
.
IsText
()
{
if
p
.
IsText
()
{
sb
.
WriteString
(
p
.
Text
)
sb
.
WriteString
(
p
.
Text
)
}
}
}
}
text
:=
sb
.
String
()
text
:=
sb
.
String
()
metadata
:=
doc
.
Metadata
metadata
:=
doc
.
Metadata
if
metadata
==
nil
{
if
metadata
==
nil
{
metadata
=
make
(
map
[
string
]
interface
{})
metadata
=
make
(
map
[
string
]
interface
{})
}
}
// Create row with explicit metadata field.
// Create row with explicit metadata field.
row
:=
make
(
map
[
string
]
interface
{})
row
:=
make
(
map
[
string
]
interface
{})
row
[
"vector"
]
=
emb
.
Embedding
// []float32
row
[
"vector"
]
=
emb
.
Embedding
// []float32
row
[
"text"
]
=
text
row
[
"text"
]
=
text
row
[
"metadata"
]
=
metadata
// Explicitly set metadata as JSON-compatible map
row
[
"metadata"
]
=
metadata
// Explicitly set metadata as JSON-compatible map
rows
=
append
(
rows
,
row
)
rows
=
append
(
rows
,
row
)
// Debug: Log row contents.
// Debug: Log row contents.
fmt
.
Printf
(
"Row %d: vector_len=%d, text=%q, metadata=%v
\n
"
,
i
,
len
(
emb
.
Embedding
),
text
,
metadata
)
fmt
.
Printf
(
"Row %d: vector_len=%d, text=%q, metadata=%v
\n
"
,
i
,
len
(
emb
.
Embedding
),
text
,
metadata
)
}
}
// Debug: Log total rows.
// Debug: Log total rows.
fmt
.
Printf
(
"Inserting %d rows into collection %q
\n
"
,
len
(
rows
),
ds
.
collection
)
fmt
.
Printf
(
"Inserting %d rows into collection %q
\n
"
,
len
(
rows
),
ds
.
collection
)
// Insert rows into Milvus.
// Insert rows into Milvus.
_
,
err
=
ds
.
client
.
InsertRows
(
ctx
,
ds
.
collection
,
""
,
rows
)
_
,
err
=
ds
.
client
.
InsertRows
(
ctx
,
ds
.
collection
,
""
,
rows
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"milvus insert rows failed: %w"
,
err
)
return
fmt
.
Errorf
(
"milvus insert rows failed: %w"
,
err
)
}
}
return
nil
return
nil
}
}
// RetrieverOptions for Milvus retrieval.
// RetrieverOptions for Milvus retrieval.
type
RetrieverOptions
struct
{
type
RetrieverOptions
struct
{
Count
int
`json:"count,omitempty"`
// Max documents to retrieve.
Count
int
`json:"count,omitempty"`
// Max documents to retrieve.
MetricType
string
`json:"metric_type,omitempty"`
// Similarity metric (e.g., "L2", "IP").
MetricType
string
`json:"metric_type,omitempty"`
// Similarity metric (e.g., "L2", "IP").
}
}
// Retrieve implements the Retriever.Retrieve method.
// Retrieve implements the Retriever.Retrieve method.
func
(
ds
*
docStore
)
Retrieve
(
ctx
context
.
Context
,
req
*
ai
.
RetrieverRequest
)
(
*
ai
.
RetrieverResponse
,
error
)
{
func
(
ds
*
docStore
)
Retrieve
(
ctx
context
.
Context
,
req
*
ai
.
RetrieverRequest
)
(
*
ai
.
RetrieverResponse
,
error
)
{
count
:=
3
// Default.
count
:=
3
// Default.
metricTypeStr
:=
"L2"
metricTypeStr
:=
"L2"
if
req
.
Options
!=
nil
{
if
req
.
Options
!=
nil
{
ropt
,
ok
:=
req
.
Options
.
(
*
RetrieverOptions
)
ropt
,
ok
:=
req
.
Options
.
(
*
RetrieverOptions
)
if
!
ok
{
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"milvus.Retrieve options have type %T, want %T"
,
req
.
Options
,
&
RetrieverOptions
{})
return
nil
,
fmt
.
Errorf
(
"milvus.Retrieve options have type %T, want %T"
,
req
.
Options
,
&
RetrieverOptions
{})
}
}
if
ropt
.
Count
>
0
{
if
ropt
.
Count
>
0
{
count
=
ropt
.
Count
count
=
ropt
.
Count
}
}
if
ropt
.
MetricType
!=
""
{
if
ropt
.
MetricType
!=
""
{
metricTypeStr
=
ropt
.
MetricType
metricTypeStr
=
ropt
.
MetricType
}
}
}
}
// Map string metric type to entity.MetricType.
// Map string metric type to entity.MetricType.
var
metricType
entity
.
MetricType
var
metricType
entity
.
MetricType
switch
metricTypeStr
{
switch
metricTypeStr
{
case
"L2"
:
case
"L2"
:
metricType
=
entity
.
L2
metricType
=
entity
.
L2
case
"IP"
:
case
"IP"
:
metricType
=
entity
.
IP
metricType
=
entity
.
IP
default
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported metric type: %s"
,
metricTypeStr
)
return
nil
,
fmt
.
Errorf
(
"unsupported metric type: %s"
,
metricTypeStr
)
}
}
// Embed query.
// Embed query.
ereq
:=
&
ai
.
EmbedRequest
{
ereq
:=
&
ai
.
EmbedRequest
{
Input
:
[]
*
ai
.
Document
{
req
.
Query
},
Input
:
[]
*
ai
.
Document
{
req
.
Query
},
Options
:
ds
.
embedderOptions
,
Options
:
ds
.
embedderOptions
,
}
}
eres
,
err
:=
ds
.
embedder
.
Embed
(
ctx
,
ereq
)
eres
,
err
:=
ds
.
embedder
.
Embed
(
ctx
,
ereq
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"milvus retrieve embedding failed: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"milvus retrieve embedding failed: %v"
,
err
)
}
}
if
len
(
eres
.
Embeddings
)
==
0
{
if
len
(
eres
.
Embeddings
)
==
0
{
return
nil
,
errors
.
New
(
"no embeddings generated for query"
)
return
nil
,
errors
.
New
(
"no embeddings generated for query"
)
}
}
queryVector
:=
entity
.
FloatVector
(
eres
.
Embeddings
[
0
]
.
Embedding
)
queryVector
:=
entity
.
FloatVector
(
eres
.
Embeddings
[
0
]
.
Embedding
)
// Create search parameters.
// Create search parameters.
searchParams
,
err
:=
entity
.
NewIndexHNSWSearchParam
(
64
)
// ef
searchParams
,
err
:=
entity
.
NewIndexHNSWSearchParam
(
64
)
// ef
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"NewIndexHNSWSearchParam failed: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"NewIndexHNSWSearchParam failed: %v"
,
err
)
}
}
// Perform search.
// Perform search.
results
,
err
:=
ds
.
client
.
Search
(
results
,
err
:=
ds
.
client
.
Search
(
ctx
,
ctx
,
ds
.
collection
,
ds
.
collection
,
[]
string
{},
// partitions
[]
string
{},
// partitions
""
,
// expr
""
,
// expr
[]
string
{
textField
,
metadataField
},
// output fields
[]
string
{
textField
,
metadataField
},
// output fields
[]
entity
.
Vector
{
queryVector
},
[]
entity
.
Vector
{
queryVector
},
vectorField
,
vectorField
,
metricType
,
metricType
,
count
,
count
,
searchParams
,
searchParams
,
)
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"milvus search failed: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"milvus search failed: %v"
,
err
)
}
}
// Process results.
// Process results.
var
docs
[]
*
ai
.
Document
var
docs
[]
*
ai
.
Document
for
_
,
result
:=
range
results
{
for
_
,
result
:=
range
results
{
for
i
:=
0
;
i
<
result
.
ResultCount
;
i
++
{
for
i
:=
0
;
i
<
result
.
ResultCount
;
i
++
{
textCol
:=
result
.
Fields
.
GetColumn
(
textField
)
textCol
:=
result
.
Fields
.
GetColumn
(
textField
)
text
,
err
:=
textCol
.
GetAsString
(
i
)
text
,
err
:=
textCol
.
GetAsString
(
i
)
if
err
!=
nil
{
if
err
!=
nil
{
continue
continue
}
}
var
metadata
map
[
string
]
interface
{}
var
metadata
map
[
string
]
interface
{}
doc
:=
ai
.
DocumentFromText
(
text
,
metadata
)
doc
:=
ai
.
DocumentFromText
(
text
,
metadata
)
docs
=
append
(
docs
,
doc
)
docs
=
append
(
docs
,
doc
)
}
}
}
}
return
&
ai
.
RetrieverResponse
{
return
&
ai
.
RetrieverResponse
{
Documents
:
docs
,
Documents
:
docs
,
},
nil
},
nil
}
}
plugins/milvus/milvus_test.go
View file @
c3342c58
...
@@ -33,156 +33,155 @@ import (
...
@@ -33,156 +33,155 @@ import (
type
MockEmbedder
struct
{}
type
MockEmbedder
struct
{}
func
(
m
*
MockEmbedder
)
Name
()
string
{
func
(
m
*
MockEmbedder
)
Name
()
string
{
return
"mock-embedder"
return
"mock-embedder"
}
}
func
(
m
*
MockEmbedder
)
Embed
(
ctx
context
.
Context
,
req
*
ai
.
EmbedRequest
)
(
*
ai
.
EmbedResponse
,
error
)
{
func
(
m
*
MockEmbedder
)
Embed
(
ctx
context
.
Context
,
req
*
ai
.
EmbedRequest
)
(
*
ai
.
EmbedResponse
,
error
)
{
resp
:=
&
ai
.
EmbedResponse
{}
resp
:=
&
ai
.
EmbedResponse
{}
for
range
req
.
Input
{
for
range
req
.
Input
{
// Generate a simple embedding (768-dimensional vector of ones)
// Generate a simple embedding (768-dimensional vector of ones)
embedding
:=
make
([]
float32
,
768
)
embedding
:=
make
([]
float32
,
768
)
for
i
:=
range
embedding
{
for
i
:=
range
embedding
{
embedding
[
i
]
=
1.0
embedding
[
i
]
=
1.0
}
}
resp
.
Embeddings
=
append
(
resp
.
Embeddings
,
&
ai
.
Embedding
{
Embedding
:
embedding
})
resp
.
Embeddings
=
append
(
resp
.
Embeddings
,
&
ai
.
Embedding
{
Embedding
:
embedding
})
}
}
return
resp
,
nil
return
resp
,
nil
}
}
// dropCollection cleans up a test collection.
// dropCollection cleans up a test collection.
func
dropCollection
(
ctx
context
.
Context
,
client
client
.
Client
,
collectionName
string
)
error
{
func
dropCollection
(
ctx
context
.
Context
,
client
client
.
Client
,
collectionName
string
)
error
{
exists
,
err
:=
client
.
HasCollection
(
ctx
,
collectionName
)
exists
,
err
:=
client
.
HasCollection
(
ctx
,
collectionName
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check collection: %w"
,
err
)
return
fmt
.
Errorf
(
"check collection: %w"
,
err
)
}
}
if
exists
{
if
exists
{
err
=
client
.
DropCollection
(
ctx
,
collectionName
)
err
=
client
.
DropCollection
(
ctx
,
collectionName
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"drop collection: %w"
,
err
)
return
fmt
.
Errorf
(
"drop collection: %w"
,
err
)
}
}
}
}
return
nil
return
nil
}
}
func
TestMilvusIntegration
(
t
*
testing
.
T
)
{
func
TestMilvusIntegration
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
// Initialize Milvus plugin
// Initialize Milvus plugin
ms
:=
Milvus
{
ms
:=
Milvus
{
Addr
:
"54.92.111.204:19530"
,
// Milvus gRPC endpoint
Addr
:
"54.92.111.204:19530"
,
// Milvus gRPC endpoint
}
}
// Initialize Genkit with Milvus plugin
// Initialize Genkit with Milvus plugin
g
,
err
:=
genkit
.
Init
(
ctx
,
genkit
.
WithPlugins
(
&
ms
))
g
,
err
:=
genkit
.
Init
(
ctx
,
genkit
.
WithPlugins
(
&
ms
))
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"genkit.Init failed: %v"
,
err
)
t
.
Fatalf
(
"genkit.Init failed: %v"
,
err
)
}
}
// Get the Milvus client for cleanup
// Get the Milvus client for cleanup
m
,
ok
:=
genkit
.
LookupPlugin
(
g
,
provider
)
.
(
*
Milvus
)
m
,
ok
:=
genkit
.
LookupPlugin
(
g
,
provider
)
.
(
*
Milvus
)
if
!
ok
{
if
!
ok
{
t
.
Fatalf
(
"Failed to lookup Milvus plugin"
)
t
.
Fatalf
(
"Failed to lookup Milvus plugin"
)
}
}
defer
m
.
client
.
Close
()
defer
m
.
client
.
Close
()
// Generate unique collection name
// Generate unique collection name
collectionName
:=
fmt
.
Sprintf
(
"test_collection_%d"
,
time
.
Now
()
.
UnixNano
())
collectionName
:=
fmt
.
Sprintf
(
"test_collection_%d"
,
time
.
Now
()
.
UnixNano
())
// Configure collection
// Configure collection
cfg
:=
CollectionConfig
{
cfg
:=
CollectionConfig
{
Collection
:
collectionName
,
Collection
:
collectionName
,
Dimension
:
768
,
// Match mock embedder dimension
Dimension
:
768
,
// Match mock embedder dimension
Embedder
:
&
MockEmbedder
{},
Embedder
:
&
MockEmbedder
{},
EmbedderOptions
:
map
[
string
]
interface
{}{},
// Explicitly set as map
EmbedderOptions
:
map
[
string
]
interface
{}{},
// Explicitly set as map
}
}
// Define indexer and retriever
// Define indexer and retriever
indexer
,
retriever
,
err
:=
DefineIndexerAndRetriever
(
ctx
,
g
,
cfg
)
indexer
,
retriever
,
err
:=
DefineIndexerAndRetriever
(
ctx
,
g
,
cfg
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"DefineIndexerAndRetriever failed: %v"
,
err
)
t
.
Fatalf
(
"DefineIndexerAndRetriever failed: %v"
,
err
)
}
}
// Clean up collection after test
// Clean up collection after test
// defer func() {
// defer func() {
// if err := dropCollection(ctx, m.client, collectionName); err != nil {
// if err := dropCollection(ctx, m.client, collectionName); err != nil {
// t.Errorf("Cleanup failed: %v", err)
// t.Errorf("Cleanup failed: %v", err)
// }
// }
// }()
// }()
t
.
Run
(
"Index and Retrieve"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Index and Retrieve"
,
func
(
t
*
testing
.
T
)
{
// Index documents
// Index documents
documents
:=
[]
*
ai
.
Document
{
documents
:=
[]
*
ai
.
Document
{
{
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)},
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)},
Metadata
:
map
[
string
]
interface
{}{
"id"
:
int64
(
1
),
"category"
:
"greeting"
},
Metadata
:
map
[
string
]
interface
{}{
"id"
:
int64
(
1
),
"category"
:
"greeting"
},
},
},
{
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"AI is amazing"
)},
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"AI is amazing"
)},
Metadata
:
map
[
string
]
interface
{}{
"id"
:
int64
(
2
),
"category"
:
"tech"
},
Metadata
:
map
[
string
]
interface
{}{
"id"
:
int64
(
2
),
"category"
:
"tech"
},
},
},
}
}
req
:=
&
ai
.
IndexerRequest
{
Documents
:
documents
}
req
:=
&
ai
.
IndexerRequest
{
Documents
:
documents
}
err
:=
indexer
.
Index
(
ctx
,
req
)
err
:=
indexer
.
Index
(
ctx
,
req
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"Index failed: %v"
,
err
)
t
.
Fatalf
(
"Index failed: %v"
,
err
)
}
}
// Wait briefly to ensure Milvus processes the index
// Wait briefly to ensure Milvus processes the index
time
.
Sleep
(
1
*
time
.
Second
)
time
.
Sleep
(
1
*
time
.
Second
)
// Retrieve documents
// Retrieve documents
queryReq
:=
&
ai
.
RetrieverRequest
{
queryReq
:=
&
ai
.
RetrieverRequest
{
Query
:
&
ai
.
Document
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)}},
Query
:
&
ai
.
Document
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)}},
Options
:
&
RetrieverOptions
{
Options
:
&
RetrieverOptions
{
Count
:
2
,
Count
:
2
,
MetricType
:
"L2"
,
MetricType
:
"L2"
,
},
},
}
}
resp
,
err
:=
retriever
.
Retrieve
(
ctx
,
queryReq
)
resp
,
err
:=
retriever
.
Retrieve
(
ctx
,
queryReq
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"Retrieve failed: %v"
,
err
)
t
.
Fatalf
(
"Retrieve failed: %v"
,
err
)
}
}
// Verify results
// Verify results
assert
.
NotNil
(
t
,
resp
,
"Response should not be nil"
)
assert
.
NotNil
(
t
,
resp
,
"Response should not be nil"
)
assert
.
NotEmpty
(
t
,
resp
.
Documents
,
"Should return at least one document"
)
assert
.
NotEmpty
(
t
,
resp
.
Documents
,
"Should return at least one document"
)
for
_
,
doc
:=
range
resp
.
Documents
{
for
_
,
doc
:=
range
resp
.
Documents
{
assert
.
NotEmpty
(
t
,
doc
.
Content
[
0
]
.
Text
,
"Document text should not be empty"
)
assert
.
NotEmpty
(
t
,
doc
.
Content
[
0
]
.
Text
,
"Document text should not be empty"
)
// Note: Mock embedder returns identical vectors, so results may not be exact
// Note: Mock embedder returns identical vectors, so results may not be exact
if
strings
.
Contains
(
doc
.
Content
[
0
]
.
Text
,
"Hello world"
)
||
strings
.
Contains
(
doc
.
Content
[
0
]
.
Text
,
"AI is amazing"
)
{
if
strings
.
Contains
(
doc
.
Content
[
0
]
.
Text
,
"Hello world"
)
||
strings
.
Contains
(
doc
.
Content
[
0
]
.
Text
,
"AI is amazing"
)
{
continue
continue
}
}
t
.
Errorf
(
"Unexpected document text: %s"
,
doc
.
Content
[
0
]
.
Text
)
t
.
Errorf
(
"Unexpected document text: %s"
,
doc
.
Content
[
0
]
.
Text
)
}
}
})
})
t
.
Run
(
"Empty Index"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Empty Index"
,
func
(
t
*
testing
.
T
)
{
req
:=
&
ai
.
IndexerRequest
{
Documents
:
[]
*
ai
.
Document
{}}
req
:=
&
ai
.
IndexerRequest
{
Documents
:
[]
*
ai
.
Document
{}}
err
:=
indexer
.
Index
(
ctx
,
req
)
err
:=
indexer
.
Index
(
ctx
,
req
)
assert
.
NoError
(
t
,
err
,
"Indexing empty documents should succeed"
)
assert
.
NoError
(
t
,
err
,
"Indexing empty documents should succeed"
)
})
})
t
.
Run
(
"Invalid Retrieve Options"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Invalid Retrieve Options"
,
func
(
t
*
testing
.
T
)
{
queryReq
:=
&
ai
.
RetrieverRequest
{
queryReq
:=
&
ai
.
RetrieverRequest
{
Query
:
&
ai
.
Document
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)}},
Query
:
&
ai
.
Document
{
Content
:
[]
*
ai
.
Part
{
ai
.
NewTextPart
(
"Hello world"
)}},
Options
:
&
RetrieverOptions
{
MetricType
:
"INVALID"
},
Options
:
&
RetrieverOptions
{
MetricType
:
"INVALID"
},
}
}
_
,
err
:=
retriever
.
Retrieve
(
ctx
,
queryReq
)
_
,
err
:=
retriever
.
Retrieve
(
ctx
,
queryReq
)
assert
.
Error
(
t
,
err
,
"Should fail with invalid metric type"
)
assert
.
Error
(
t
,
err
,
"Should fail with invalid metric type"
)
assert
.
Contains
(
t
,
err
.
Error
(),
"unsupported metric type"
)
assert
.
Contains
(
t
,
err
.
Error
(),
"unsupported metric type"
)
})
})
t
.
Run
(
"Invalid Embedder Options"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Invalid Embedder Options"
,
func
(
t
*
testing
.
T
)
{
// Test with invalid EmbedderOptions type
// Test with invalid EmbedderOptions type
invalidCfg
:=
CollectionConfig
{
invalidCfg
:=
CollectionConfig
{
Collection
:
collectionName
+
"_invalid"
,
Collection
:
collectionName
+
"_invalid"
,
Dimension
:
768
,
Dimension
:
768
,
Embedder
:
&
MockEmbedder
{},
Embedder
:
&
MockEmbedder
{},
EmbedderOptions
:
"not-a-map"
,
// Invalid type
EmbedderOptions
:
"not-a-map"
,
// Invalid type
}
}
_
,
_
,
err
:=
DefineIndexerAndRetriever
(
ctx
,
g
,
invalidCfg
)
_
,
_
,
err
:=
DefineIndexerAndRetriever
(
ctx
,
g
,
invalidCfg
)
assert
.
Error
(
t
,
err
,
"Should fail with invalid EmbedderOptions type"
)
assert
.
Error
(
t
,
err
,
"Should fail with invalid EmbedderOptions type"
)
assert
.
Contains
(
t
,
err
.
Error
(),
"EmbedderOptions must be a map[string]interface{}"
)
assert
.
Contains
(
t
,
err
.
Error
(),
"EmbedderOptions must be a map[string]interface{}"
)
})
})
}
}
qa.go
View file @
c3342c58
...
@@ -12,188 +12,188 @@ import (
...
@@ -12,188 +12,188 @@ import (
)
)
var
(
var
(
connString
=
flag
.
String
(
"dbconn"
,
""
,
"database connection string"
)
connString
=
flag
.
String
(
"dbconn"
,
""
,
"database connection string"
)
)
)
// QA 结构体表示 qa 表的记录
// QA 结构体表示 qa 表的记录
type
QA
struct
{
type
QA
struct
{
ID
int64
// 主键
ID
int64
// 主键
CreatedAt
time
.
Time
// 创建时间
CreatedAt
time
.
Time
// 创建时间
UserID
*
int64
// 可空的用户 ID
UserID
*
int64
// 可空的用户 ID
Username
*
string
// 可空的用户名
Username
*
string
// 可空的用户名
Question
*
string
// 可空的问题
Question
*
string
// 可空的问题
Answer
*
string
// 可空的答案
Answer
*
string
// 可空的答案
}
}
// QAStore 定义 DAO 接口
// QAStore 定义 DAO 接口
type
QAStore
interface
{
type
QAStore
interface
{
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
// GetLatestQA 从 latest_qa 视图读取指定 user_id 的最新记录
GetLatestQA
(
ctx
context
.
Context
,
userID
*
int64
)
([]
QA
,
error
)
GetLatestQA
(
ctx
context
.
Context
,
userID
*
int64
)
([]
QA
,
error
)
// WriteQA 插入或更新 qa 表记录
// WriteQA 插入或更新 qa 表记录
WriteQA
(
ctx
context
.
Context
,
qa
QA
)
(
int64
,
error
)
WriteQA
(
ctx
context
.
Context
,
qa
QA
)
(
int64
,
error
)
}
}
// qaStore 是 QAStore 接口的实现
// qaStore 是 QAStore 接口的实现
type
qaStore
struct
{
type
qaStore
struct
{
db
*
sql
.
DB
db
*
sql
.
DB
}
}
// NewQAStore 创建新的 QAStore 实例
// NewQAStore 创建新的 QAStore 实例
func
NewQAStore
(
db
*
sql
.
DB
)
QAStore
{
func
NewQAStore
(
db
*
sql
.
DB
)
QAStore
{
return
&
qaStore
{
db
:
db
}
return
&
qaStore
{
db
:
db
}
}
}
// GetLatestQA 从 latest_qa 视图读取数据
// GetLatestQA 从 latest_qa 视图读取数据
func
(
s
*
qaStore
)
GetLatestQA
(
ctx
context
.
Context
,
userID
*
int64
)
([]
QA
,
error
)
{
func
(
s
*
qaStore
)
GetLatestQA
(
ctx
context
.
Context
,
userID
*
int64
)
([]
QA
,
error
)
{
query
:=
`
query
:=
`
SELECT id, created_at, user_id, username, question, answer
SELECT id, created_at, user_id, username, question, answer
FROM latest_qa
FROM latest_qa
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`
args
:=
[]
interface
{}{
userID
}
args
:=
[]
interface
{}{
userID
}
if
userID
==
nil
{
if
userID
==
nil
{
args
=
[]
interface
{}{
nil
}
args
=
[]
interface
{}{
nil
}
}
}
rows
,
err
:=
s
.
db
.
QueryContext
(
ctx
,
query
,
args
...
)
rows
,
err
:=
s
.
db
.
QueryContext
(
ctx
,
query
,
args
...
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query latest_qa: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"query latest_qa: %w"
,
err
)
}
}
defer
rows
.
Close
()
defer
rows
.
Close
()
var
results
[]
QA
var
results
[]
QA
for
rows
.
Next
()
{
for
rows
.
Next
()
{
var
qa
QA
var
qa
QA
var
userIDVal
sql
.
NullInt64
var
userIDVal
sql
.
NullInt64
var
username
,
question
,
answer
sql
.
NullString
var
username
,
question
,
answer
sql
.
NullString
if
err
:=
rows
.
Scan
(
&
qa
.
ID
,
&
qa
.
CreatedAt
,
&
userIDVal
,
&
username
,
&
question
,
&
answer
);
err
!=
nil
{
if
err
:=
rows
.
Scan
(
&
qa
.
ID
,
&
qa
.
CreatedAt
,
&
userIDVal
,
&
username
,
&
question
,
&
answer
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan row: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"scan row: %w"
,
err
)
}
}
if
userIDVal
.
Valid
{
if
userIDVal
.
Valid
{
qa
.
UserID
=
&
userIDVal
.
Int64
qa
.
UserID
=
&
userIDVal
.
Int64
}
}
if
username
.
Valid
{
if
username
.
Valid
{
qa
.
Username
=
&
username
.
String
qa
.
Username
=
&
username
.
String
}
}
if
question
.
Valid
{
if
question
.
Valid
{
qa
.
Question
=
&
question
.
String
qa
.
Question
=
&
question
.
String
}
}
if
answer
.
Valid
{
if
answer
.
Valid
{
qa
.
Answer
=
&
answer
.
String
qa
.
Answer
=
&
answer
.
String
}
}
results
=
append
(
results
,
qa
)
results
=
append
(
results
,
qa
)
}
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"row iteration: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"row iteration: %w"
,
err
)
}
}
return
results
,
nil
return
results
,
nil
}
}
// WriteQA 插入或更新 qa 表记录
// WriteQA 插入或更新 qa 表记录
func
(
s
*
qaStore
)
WriteQA
(
ctx
context
.
Context
,
qa
QA
)
(
int64
,
error
)
{
func
(
s
*
qaStore
)
WriteQA
(
ctx
context
.
Context
,
qa
QA
)
(
int64
,
error
)
{
if
qa
.
ID
!=
0
{
if
qa
.
ID
!=
0
{
// 更新记录
// 更新记录
query
:=
`
query
:=
`
UPDATE qa
UPDATE qa
SET user_id = $1, username = $2, question = $3, answer = $4
SET user_id = $1, username = $2, question = $3, answer = $4
WHERE id = $5
WHERE id = $5
RETURNING id`
RETURNING id`
var
updatedID
int64
var
updatedID
int64
err
:=
s
.
db
.
QueryRowContext
(
ctx
,
query
,
qa
.
UserID
,
qa
.
Username
,
qa
.
Question
,
qa
.
Answer
,
qa
.
ID
)
.
Scan
(
&
updatedID
)
err
:=
s
.
db
.
QueryRowContext
(
ctx
,
query
,
qa
.
UserID
,
qa
.
Username
,
qa
.
Question
,
qa
.
Answer
,
qa
.
ID
)
.
Scan
(
&
updatedID
)
if
err
==
sql
.
ErrNoRows
{
if
err
==
sql
.
ErrNoRows
{
return
0
,
fmt
.
Errorf
(
"no record found with id %d"
,
qa
.
ID
)
return
0
,
fmt
.
Errorf
(
"no record found with id %d"
,
qa
.
ID
)
}
}
if
err
!=
nil
{
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"update qa: %w"
,
err
)
return
0
,
fmt
.
Errorf
(
"update qa: %w"
,
err
)
}
}
return
updatedID
,
nil
return
updatedID
,
nil
}
}
// 插入新记录
// 插入新记录
query
:=
`
query
:=
`
INSERT INTO qa (user_id, username, question, answer)
INSERT INTO qa (user_id, username, question, answer)
VALUES ($1, $2, $3, $4)
VALUES ($1, $2, $3, $4)
RETURNING id`
RETURNING id`
var
newID
int64
var
newID
int64
err
:=
s
.
db
.
QueryRowContext
(
ctx
,
query
,
qa
.
UserID
,
qa
.
Username
,
qa
.
Question
,
qa
.
Answer
)
.
Scan
(
&
newID
)
err
:=
s
.
db
.
QueryRowContext
(
ctx
,
query
,
qa
.
UserID
,
qa
.
Username
,
qa
.
Question
,
qa
.
Answer
)
.
Scan
(
&
newID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"insert qa: %w"
,
err
)
return
0
,
fmt
.
Errorf
(
"insert qa: %w"
,
err
)
}
}
return
newID
,
nil
return
newID
,
nil
}
}
func
mainQA
()
{
func
mainQA
()
{
flag
.
Parse
()
flag
.
Parse
()
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
if
*
connString
==
""
{
if
*
connString
==
""
{
log
.
Fatal
(
"need -dbconn"
)
log
.
Fatal
(
"need -dbconn"
)
}
}
db
,
err
:=
sql
.
Open
(
"postgres"
,
*
connString
)
db
,
err
:=
sql
.
Open
(
"postgres"
,
*
connString
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Fatalf
(
"open database: %v"
,
err
)
log
.
Fatalf
(
"open database: %v"
,
err
)
}
}
defer
db
.
Close
()
defer
db
.
Close
()
store
:=
NewQAStore
(
db
)
store
:=
NewQAStore
(
db
)
// 示例:读取 user_id=101 的最新 QA
// 示例:读取 user_id=101 的最新 QA
results
,
err
:=
store
.
GetLatestQA
(
ctx
,
int64Ptr
(
101
))
results
,
err
:=
store
.
GetLatestQA
(
ctx
,
int64Ptr
(
101
))
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Fatalf
(
"get latest QA: %v"
,
err
)
log
.
Fatalf
(
"get latest QA: %v"
,
err
)
}
}
for
_
,
qa
:=
range
results
{
for
_
,
qa
:=
range
results
{
fmt
.
Printf
(
"ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v
\n
"
,
fmt
.
Printf
(
"ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v
\n
"
,
qa
.
ID
,
qa
.
CreatedAt
,
derefInt64
(
qa
.
UserID
),
derefString
(
qa
.
Username
),
derefString
(
qa
.
Question
),
derefString
(
qa
.
Answer
))
qa
.
ID
,
qa
.
CreatedAt
,
derefInt64
(
qa
.
UserID
),
derefString
(
qa
.
Username
),
derefString
(
qa
.
Question
),
derefString
(
qa
.
Answer
))
}
}
// 示例:插入新 QA
// 示例:插入新 QA
newQA
:=
QA
{
newQA
:=
QA
{
UserID
:
int64Ptr
(
101
),
UserID
:
int64Ptr
(
101
),
Username
:
stringPtr
(
"alice"
),
Username
:
stringPtr
(
"alice"
),
Question
:
stringPtr
(
"What is AI?"
),
Question
:
stringPtr
(
"What is AI?"
),
Answer
:
stringPtr
(
"AI is..."
),
Answer
:
stringPtr
(
"AI is..."
),
}
}
newID
,
err
:=
store
.
WriteQA
(
ctx
,
newQA
)
newID
,
err
:=
store
.
WriteQA
(
ctx
,
newQA
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Fatalf
(
"write QA: %v"
,
err
)
log
.
Fatalf
(
"write QA: %v"
,
err
)
}
}
fmt
.
Printf
(
"Inserted QA with ID: %d
\n
"
,
newID
)
fmt
.
Printf
(
"Inserted QA with ID: %d
\n
"
,
newID
)
// 示例:更新 QA
// 示例:更新 QA
updateQA
:=
QA
{
updateQA
:=
QA
{
ID
:
newID
,
ID
:
newID
,
UserID
:
int64Ptr
(
101
),
UserID
:
int64Ptr
(
101
),
Username
:
stringPtr
(
"alice_updated"
),
Username
:
stringPtr
(
"alice_updated"
),
Question
:
stringPtr
(
"What is NLP?"
),
Question
:
stringPtr
(
"What is NLP?"
),
Answer
:
stringPtr
(
"NLP is..."
),
Answer
:
stringPtr
(
"NLP is..."
),
}
}
updatedID
,
err
:=
store
.
WriteQA
(
ctx
,
updateQA
)
updatedID
,
err
:=
store
.
WriteQA
(
ctx
,
updateQA
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Fatalf
(
"update QA: %v"
,
err
)
log
.
Fatalf
(
"update QA: %v"
,
err
)
}
}
fmt
.
Printf
(
"Updated QA with ID: %d
\n
"
,
updatedID
)
fmt
.
Printf
(
"Updated QA with ID: %d
\n
"
,
updatedID
)
}
}
// 辅助函数:处理指针类型的空值
// 辅助函数:处理指针类型的空值
func
int64Ptr
(
i
int64
)
*
int64
{
func
int64Ptr
(
i
int64
)
*
int64
{
return
&
i
return
&
i
}
}
func
stringPtr
(
s
string
)
*
string
{
func
stringPtr
(
s
string
)
*
string
{
return
&
s
return
&
s
}
}
func
derefInt64
(
p
*
int64
)
interface
{}
{
func
derefInt64
(
p
*
int64
)
interface
{}
{
if
p
==
nil
{
if
p
==
nil
{
return
nil
return
nil
}
}
return
*
p
return
*
p
}
}
func
derefString
(
p
*
string
)
interface
{}
{
func
derefString
(
p
*
string
)
interface
{}
{
if
p
==
nil
{
if
p
==
nil
{
return
nil
return
nil
}
}
return
*
p
return
*
p
}
}
\ No newline at end of file
rate.go
0 → 100644
View file @
c3342c58
package
main
import
(
"context"
"net/http"
"sync"
"golang.org/x/time/rate"
)
// RateLimiter 定义限速器和并发队列
type
RateLimiter
struct
{
limiter
*
rate
.
Limiter
queue
chan
struct
{}
maxWorkers
int
mu
sync
.
Mutex
}
// NewRateLimiter 初始化限速器
func
NewRateLimiter
(
ratePerSecond
float64
,
burst
,
maxWorkers
int
)
*
RateLimiter
{
return
&
RateLimiter
{
limiter
:
rate
.
NewLimiter
(
rate
.
Limit
(
ratePerSecond
),
burst
),
queue
:
make
(
chan
struct
{},
maxWorkers
),
maxWorkers
:
maxWorkers
,
}
}
// Allow 检查是否允许请求
func
(
rl
*
RateLimiter
)
Allow
(
ctx
context
.
Context
)
bool
{
rl
.
mu
.
Lock
()
defer
rl
.
mu
.
Unlock
()
if
err
:=
rl
.
limiter
.
Wait
(
ctx
);
err
!=
nil
{
return
false
}
select
{
case
rl
.
queue
<-
struct
{}{}
:
return
true
default
:
return
false
}
}
// Release 释放并发槽
func
(
rl
*
RateLimiter
)
Release
()
{
<-
rl
.
queue
}
// Middleware HTTP 中间件
func
(
rl
*
RateLimiter
)
Middleware
(
next
http
.
Handler
)
http
.
Handler
{
return
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
ctx
:=
r
.
Context
()
if
!
rl
.
Allow
(
ctx
)
{
http
.
Error
(
w
,
"Too Many Requests"
,
http
.
StatusTooManyRequests
)
return
}
defer
rl
.
Release
()
next
.
ServeHTTP
(
w
,
r
)
})
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment