Commit 96aff5b7 authored by Wade's avatar Wade

api ok

parent c3342c58
...@@ -87,3 +87,5 @@ require ( ...@@ -87,3 +87,5 @@ require (
google.golang.org/protobuf v1.36.6 // indirect google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )
replace github.com/firebase/genkit/go => ../genkit/go
...@@ -13,6 +13,15 @@ import ( ...@@ -13,6 +13,15 @@ import (
"github.com/firebase/genkit/go/plugins/server" "github.com/firebase/genkit/go/plugins/server"
) )
type Input struct {
Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Username string `json:"username,omitempty"`
UserID string `json:"user_id,omitempty"`
}
func main() { func main() {
ctx := context.Background() ctx := context.Background()
...@@ -34,7 +43,10 @@ func main() { ...@@ -34,7 +43,10 @@ func main() {
nil) nil)
// Define a simple flow that generates jokes about a given topic // Define a simple flow that generates jokes about a given topic
genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { genkit.DefineFlow(g, "chat", func(ctx context.Context, input *Input) (string, error) {
fmt.Println("input-------------------------------",input.Content)
resp, err := genkit.Generate(ctx, g, resp, err := genkit.Generate(ctx, g,
ai.WithModel(m), ai.WithModel(m),
ai.WithPrompt(`Tell silly short jokes about apple`)) ai.WithPrompt(`Tell silly short jokes about apple`))
...@@ -66,8 +78,8 @@ func main() { ...@@ -66,8 +78,8 @@ func main() {
} }
// 启动服务器,监听 // 启动服务器,监听
log.Printf("Server starting on 0.0.0.0:3400") log.Printf("Server starting on 0.0.0.0:8000")
if err := server.Start(ctx, "0.0.0.0:3400", mux); err != nil { if err := server.Start(ctx, "0.0.0.0:8000", mux); err != nil {
log.Fatalf("Server failed: %v", err) log.Fatalf("Server failed: %v", err)
} }
} }
package main
import (
"context"
"fmt"
"log"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/googlegenai"
)
// ModelRouter 定义模型路由器
type ModelRouter struct {
models map[string]func(apiKey string) ai.Model
}
// NewModelRouter 初始化模型路由器
func NewModelRouter() *ModelRouter {
return &ModelRouter{
models: map[string]func(apiKey string) ai.Model{
"gemini-1.5-flash": func(apiKey string) ai.Model {
return googlegenai.GoogleAIModel(g, "gemini-2.5-pro-preview-03-25")
},
"deepseek": func(apiKey string) ai.Model {
// 假设 DeepSeek 插件实现
// 示例:return deepseek.Model("deepseek-model", deepseek.WithAPIKey(apiKey))
log.Printf("DeepSeek not implemented, using placeholder")
return nil // 替换为实际 DeepSeek 插件
},
},
}
}
// GetModel 根据模型名称和 API 密钥获取模型
func (mr *ModelRouter) GetModel(modelName, apiKey string) (ai.Model, error) {
if modelName == "" {
return nil, fmt.Errorf("model parameter is required")
}
modelFunc, ok := mr.models[modelName]
if !ok {
return nil, fmt.Errorf("unsupported model: %s", modelName)
}
model := modelFunc(apiKey)
if model == nil {
return nil, fmt.Errorf("failed to initialize model: %s", modelName)
}
return model, nil
}
type Input struct {
Content []*ai.Part `json:"content,omitempty"`
Model string `json:"model,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Username string `json:"username,omitempty"`
UserID string `json:"user_id,omitempty"`
}
func defineChatFlow(g *genkit.Genkit, mr *ModelRouter) {
genkit.DefineFlow(g, "chat", func(ctx context.Context, input Input) (string, error) {
// Get the model
model, err := mr.GetModel(input.Model, input.APIKey)
if err != nil {
return "", fmt.Errorf("failed to get model: %v", err)
}
// Add user context to the prompt if available
userContext := ""
if input.Username != "" {
userContext = fmt.Sprintf("User %s ", input.Username)
}
if input.UserID != "" {
userContext += fmt.Sprintf("(ID: %s) ", input.UserID)
}
resp, err := genkit.Generate(ctx, g,
ai.WithModel(model),
ai.WithPrompt(fmt.Sprintf("%sasks: %s", userContext, input.Content)))
if err != nil {
fmt.Println(err.Error())
return "", err
}
fmt.Println("resp.Text()", resp.Text())
return resp.Text(), nil
})
}
...@@ -273,7 +273,9 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer { ...@@ -273,7 +273,9 @@ func Indexer(g *genkit.Genkit, collection string) ai.Indexer {
func Retriever(g *genkit.Genkit, collection string) ai.Retriever { func Retriever(g *genkit.Genkit, collection string) ai.Retriever {
return genkit.LookupRetriever(g, provider, collection) return genkit.LookupRetriever(g, provider, collection)
} }
/*
更新 删除 很少用到;
*/
// Index implements the Indexer.Index method. // Index implements the Indexer.Index method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
if len(req.Documents) == 0 { if len(req.Documents) == 0 {
...@@ -388,6 +390,8 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai ...@@ -388,6 +390,8 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err) return nil, fmt.Errorf("NewIndexHNSWSearchParam failed: %v", err)
} }
// TODO 元数据 过滤条件
// Perform search. // Perform search.
results, err := ds.client.Search( results, err := ds.client.Search(
ctx, ctx,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment