Commit 4b2ea618 authored by Wade

add ollama param

parent 8c52cd01
......@@ -9,16 +9,18 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge"
"github.com/wade-liwei/agentchat/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/question"
"github.com/wade-liwei/agentchat/util"
"google.golang.org/genai"
// Import knowledge package
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/ollama"
"github.com/rs/zerolog/log"
_ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档
)
......
......@@ -12,8 +12,8 @@ import (
"github.com/wade-liwei/agentchat/plugins/graphrag" // Import knowledge package
"github.com/wade-liwei/agentchat/plugins/knowledge"
"github.com/wade-liwei/agentchat/plugins/milvus"
"github.com/wade-liwei/agentchat/plugins/ollama"
"github.com/firebase/genkit/go/plugins/evaluators"
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/server"
......@@ -80,25 +80,15 @@ func main() {
os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
ctx := context.Background()
metrics := []evaluators.MetricConfig{
{
MetricType: evaluators.EvaluatorDeepEqual,
},
{
MetricType: evaluators.EvaluatorRegex,
},
{
MetricType: evaluators.EvaluatorJsonata,
},
}
// Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"},
&deepseek.DeepSeek{APIKey: *deepseekAPIKey},
&milvus.Milvus{Addr: *milvusAddr},
&graphrag.GraphKnowledge{Addr: *graphragAddr},
&googlegenai.GoogleAI{APIKey: *googleAIApiKey},
&evaluators.GenkitEval{Metrics: metrics},
))
if err != nil {
......
This diff is collapsed.
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
// Caller-facing options and JSON wire types used by the embed call.
type (
	// EmbedOptions carries the per-request embedding settings supplied by callers.
	EmbedOptions struct {
		// Model is the name of the embedding model to use.
		Model string `json:"model"`
	}

	// ollamaEmbedRequest is the JSON body sent to the embed endpoint.
	ollamaEmbedRequest struct {
		Model string `json:"model"`
		// Input holds a single string when exactly one document is embedded
		// and a []string otherwise.
		// TODO: using any to handle both string and []string, figure out better solution
		Input   any            `json:"input"`
		Options map[string]any `json:"options,omitempty"`
	}

	// ollamaEmbedResponse is the JSON body decoded from a successful embed reply.
	ollamaEmbedResponse struct {
		Embeddings [][]float32 `json:"embeddings"`
	}
)
// embed sends the documents in req to the Ollama server at serverAddress and
// converts the vectors in the server's reply into an *ai.EmbedResponse.
//
// req.Options must hold a *EmbedOptions with a non-empty Model; a nil Options
// value is rejected because no model can be derived from it. An error is
// returned when the options or the server address are invalid, when the
// request cannot be marshalled or sent, when the server answers with a status
// other than 200, or when the reply cannot be decoded.
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
	options, ok := req.Options.(*EmbedOptions)
	if !ok && req.Options != nil {
		return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
	}
	// Covers both a nil interface and a typed nil *EmbedOptions.
	if options == nil || options.Model == "" {
		return nil, fmt.Errorf("invalid embedding model: model must be specified")
	}
	if serverAddress == "" {
		return nil, fmt.Errorf("invalid server address: address cannot be empty")
	}

	jsonData, err := json.Marshal(newOllamaEmbedRequest(options.Model, req.Input))
	if err != nil {
		return nil, fmt.Errorf("failed to marshal embed request: %w", err)
	}

	resp, err := sendEmbedRequest(ctx, serverAddress, jsonData)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: surface the server's own explanation instead of only
		// the status code. NOTE(review): assumes Ollama reports failures as a
		// JSON object with an "error" string; any other body falls back to
		// the bare status message below.
		var failure struct {
			Error string `json:"error"`
		}
		if decErr := json.NewDecoder(resp.Body).Decode(&failure); decErr == nil && failure.Error != "" {
			return nil, fmt.Errorf("ollama embed request failed with status code %d: %s", resp.StatusCode, failure.Error)
		}
		return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode)
	}

	var ollamaResp ollamaEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return nil, fmt.Errorf("failed to decode embed response: %w", err)
	}
	return newEmbedResponse(ollamaResp.Embeddings), nil
}
// sendEmbedRequest POSTs jsonData with a JSON content type to the /api/embed
// endpoint of the server at serverAddress and returns the raw HTTP response.
// On success the caller is responsible for closing the response body.
//
// Trailing slashes on serverAddress are ignored, so an address such as
// "http://localhost:11434/" does not produce a "//api/embed" request path.
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
	endpoint := strings.TrimRight(serverAddress, "/") + "/api/embed"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	return client.Do(httpReq)
}
// newOllamaEmbedRequest builds the embed payload for model from documents.
// Exactly one document is sent as a bare string; any other count is sent as a
// list of strings with one entry per document.
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest {
	payload := ollamaEmbedRequest{Model: model}

	if len(documents) == 1 {
		payload.Input = concatenateText(documents[0])
		return payload
	}

	batch := make([]string, 0, len(documents))
	for _, d := range documents {
		batch = append(batch, concatenateText(d))
	}
	payload.Input = batch
	return payload
}
// newEmbedResponse wraps each raw vector in an *ai.Embedding, preserving order,
// and returns them as an *ai.EmbedResponse.
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse {
	wrapped := make([]*ai.Embedding, 0, len(embeddings))
	for _, vec := range embeddings {
		wrapped = append(wrapped, &ai.Embedding{Embedding: vec})
	}
	return &ai.EmbedResponse{Embeddings: wrapped}
}
// concatenateText joins the Text field of every part of doc, in order and
// without any separator.
// NOTE(review): the Text of every part is appended regardless of part kind;
// confirm that is intended for non-text parts.
func concatenateText(doc *ai.Document) string {
	var sb strings.Builder
	for i := range doc.Content {
		sb.WriteString(doc.Content[i].Text)
	}
	return sb.String()
}
// DefineEmbedder registers an embedder under the name serverAddress whose
// requests are sent to the Ollama server at that address, and returns it.
//
// model is the fallback model name: it is used when a request carries no
// options or carries *EmbedOptions with an empty Model. Requests whose options
// have any other type are passed through unchanged so that embed rejects them
// with an error instead of the embedder panicking on a failed type assertion.
//
// DefineEmbedder panics if the plugin has not been initialized.
func (o *Ollama) DefineEmbedder(g *genkit.Genkit, serverAddress string, model string) ai.Embedder {
	o.mu.Lock()
	defer o.mu.Unlock()
	if !o.initted {
		panic("ollama.Init not called")
	}
	return genkit.DefineEmbedder(g, provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
		switch opts := req.Options.(type) {
		case nil:
			req.Options = &EmbedOptions{Model: model}
		case *EmbedOptions:
			if opts == nil {
				// A typed nil pointer carries no settings; treat it like missing options.
				req.Options = &EmbedOptions{Model: model}
			} else if opts.Model == "" {
				opts.Model = model
			}
		}
		return embed(ctx, serverAddress, req)
	})
}
// IsDefinedEmbedder reports whether an embedder registered under serverAddress
// can be looked up for this plugin's provider.
func IsDefinedEmbedder(g *genkit.Genkit, serverAddress string) bool {
	found := genkit.LookupEmbedder(g, provider, serverAddress)
	return found != nil
}
// Embedder looks up the [ai.Embedder] registered under serverAddress for this
// plugin's provider. The result is nil when no such embedder was defined.
func Embedder(g *genkit.Genkit, serverAddress string) ai.Embedder {
	e := genkit.LookupEmbedder(g, provider, serverAddress)
	return e
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/firebase/genkit/go/ai"
)
// TestEmbedValidRequest checks that embed POSTs to the embed endpoint of the
// given server and turns a successful reply into a single embedding that
// carries the vector returned by the server.
func TestEmbedValidRequest(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (not Fatalf) because the handler runs on a server goroutine.
		if r.Method != http.MethodPost || r.URL.Path != "/api/embed" {
			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		err := json.NewEncoder(w).Encode(ollamaEmbedResponse{
			Embeddings: [][]float32{{0.1, 0.2, 0.3}},
		})
		if err != nil {
			t.Errorf("failed to write fake response: %v", err)
		}
	}))
	defer server.Close()

	req := &ai.EmbedRequest{
		Input: []*ai.Document{
			ai.DocumentFromText("test", nil),
		},
		Options: &EmbedOptions{Model: "all-minilm"},
	}

	resp, err := embed(context.Background(), server.URL, req)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if len(resp.Embeddings) != 1 {
		t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings))
	}
	if got := len(resp.Embeddings[0].Embedding); got != 3 {
		t.Fatalf("expected embedding of length 3, got %d", got)
	}
}
// TestEmbedInvalidServerAddress checks that embed rejects an empty server
// address with an "invalid server address" error.
func TestEmbedInvalidServerAddress(t *testing.T) {
	docs := []*ai.Document{ai.DocumentFromText("test", nil)}
	req := &ai.EmbedRequest{
		Input:   docs,
		Options: &EmbedOptions{Model: "all-minilm"},
	}

	_, err := embed(context.Background(), "", req)
	if err != nil && strings.Contains(err.Error(), "invalid server address") {
		return
	}
	t.Fatalf("expected invalid server address error, got %v", err)
}
This diff is collapsed.
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama_test
import (
"context"
"flag"
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
ollamaPlugin "github.com/firebase/genkit/go/plugins/ollama"
)
// Command-line flags that configure the live test.
var (
	// serverAddress is the address of the Ollama server to talk to.
	serverAddress = flag.String("server-address", "http://localhost:11434", "Ollama server address")
	// modelName is the name of the model to define and query.
	modelName = flag.String("model-name", "tinyllama", "model name")
	// testLive must be set for the live test to run; it is off by default.
	testLive = flag.Bool("test-live", false, "run live tests")
)
// TestLive generates text with a model served by a running Ollama server.
//
// The test is skipped unless the -test-live flag is set. The server address is
// taken from the -server-address flag, which defaults to
// http://localhost:11434, and the model from the -model-name flag.
func TestLive(t *testing.T) {
	if !*testLive {
		t.Skip("skipping go/plugins/ollama live test")
	}

	ctx := context.Background()

	g, err := genkit.Init(ctx)
	if err != nil {
		t.Fatal(err)
	}

	o := &ollamaPlugin.Ollama{ServerAddress: *serverAddress}

	// Initialize the Ollama plugin.
	if err = o.Init(ctx, g); err != nil {
		t.Fatalf("failed to initialize Ollama plugin: %s", err)
	}

	// Define the model, then look it up again by name.
	o.DefineModel(g, ollamaPlugin.ModelDefinition{Name: *modelName}, nil)
	m := ollamaPlugin.Model(g, *modelName)
	if m == nil {
		t.Fatalf(`failed to find model: %s`, *modelName)
	}

	// Generate a response from the model.
	resp, err := genkit.Generate(ctx, g,
		ai.WithModel(m),
		ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 1}),
		ai.WithPrompt("I'm hungry, what should I eat?"),
	)
	if err != nil {
		t.Fatalf("failed to generate response: %s", err)
	}
	if resp == nil {
		t.Fatalf("response is nil")
	}

	// Only the presence of text is checked; the content is not asserted.
	text := resp.Text()
	if text == "" {
		t.Fatalf("expected non-empty response, got: %s", text)
	}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
// Compile-time check that *Ollama satisfies the genkit.Plugin interface.
var _ genkit.Plugin = (*Ollama)(nil)
// TestConcatMessages checks that concatMessages returns the text of the
// messages whose role is among the requested roles, and an empty string when
// no message matches.
func TestConcatMessages(t *testing.T) {
	// textMessage builds a message with a single text part for the given role.
	textMessage := func(role ai.Role, text string) *ai.Message {
		return &ai.Message{
			Role:    role,
			Content: []*ai.Part{ai.NewTextPart(text)},
		}
	}

	type testCase struct {
		name     string
		messages []*ai.Message
		roles    []ai.Role
		want     string
	}

	cases := []testCase{
		{
			name:     "Single message with matching role",
			messages: []*ai.Message{textMessage(ai.RoleUser, "Hello, how are you?")},
			roles:    []ai.Role{ai.RoleUser},
			want:     "Hello, how are you?",
		},
		{
			name: "Multiple messages with mixed roles",
			messages: []*ai.Message{
				textMessage(ai.RoleUser, "Tell me a joke."),
				textMessage(ai.RoleModel, "Why did the scarecrow win an award? Because he was outstanding in his field!"),
			},
			roles: []ai.Role{ai.RoleModel},
			want:  "Why did the scarecrow win an award? Because he was outstanding in his field!",
		},
		{
			name:     "No matching role",
			messages: []*ai.Message{textMessage(ai.RoleUser, "Any suggestions?")},
			roles:    []ai.Role{ai.RoleSystem},
			want:     "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := concatMessages(&ai.ModelRequest{Messages: tc.messages}, tc.roles)
			if got == tc.want {
				return
			}
			t.Errorf("concatMessages() = %q, want %q", got, tc.want)
		})
	}
}
// TestTranslateGenerateChunk checks that translateGenerateChunk turns a JSON
// generate reply into a chunk holding the reply text, and that it reports an
// error for malformed JSON.
func TestTranslateGenerateChunk(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    *ai.ModelResponseChunk
		wantErr bool
	}{
		{
			name:  "Valid JSON response",
			input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`,
			want: &ai.ModelResponseChunk{
				Content: []*ai.Part{ai.NewTextPart("This is a test response.")},
			},
			wantErr: false,
		},
		{
			name:    "Invalid JSON",
			input:   `{invalid}`,
			want:    nil,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := translateGenerateChunk(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				return
			}
			// Guard the dereference below: a nil chunk without an error is a
			// test failure, not a panic.
			if got == nil {
				t.Fatalf("translateGenerateChunk() returned a nil chunk without an error")
			}
			if !equalContent(got.Content, tt.want.Content) {
				t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// equalContent reports whether a and b have the same length and, position by
// position, hold text parts with identical text.
func equalContent(a, b []*ai.Part) bool {
	if len(a) != len(b) {
		return false
	}
	for i, left := range a {
		right := b[i]
		matches := left.Text == right.Text && left.IsText() && right.IsText()
		if !matches {
			return false
		}
	}
	return true
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment