Commit 4b2ea618 authored by Wade

add ollama param

parent 8c52cd01
@@ -9,16 +9,18 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
//"github.com/firebase/genkit/go/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/deepseek"
"github.com/wade-liwei/agentchat/plugins/graphrag"
"github.com/wade-liwei/agentchat/plugins/knowledge"
"github.com/wade-liwei/agentchat/plugins/ollama"
"github.com/wade-liwei/agentchat/plugins/question"
"github.com/wade-liwei/agentchat/util"
"google.golang.org/genai"
// Import knowledge package
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/ollama"
"github.com/rs/zerolog/log"
_ "github.com/wade-liwei/agentchat/docs" // 导入生成的 Swagger 文档
)
@@ -12,8 +12,8 @@ import (
"github.com/wade-liwei/agentchat/plugins/graphrag" // Import knowledge package
"github.com/wade-liwei/agentchat/plugins/knowledge"
"github.com/wade-liwei/agentchat/plugins/milvus"
"github.com/wade-liwei/agentchat/plugins/ollama"
"github.com/firebase/genkit/go/plugins/evaluators"
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/server"
@@ -80,25 +80,15 @@ func main() {
os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
ctx := context.Background()
metrics := []evaluators.MetricConfig{
{
MetricType: evaluators.EvaluatorDeepEqual,
},
{
MetricType: evaluators.EvaluatorRegex,
},
{
MetricType: evaluators.EvaluatorJsonata,
},
}
// Initialize genkit with plugins using flag/env values
g, err := genkit.Init(ctx, genkit.WithPlugins(
&ollama.Ollama{ServerAddress: "http://localhost:11434"},
&deepseek.DeepSeek{APIKey: *deepseekAPIKey},
&milvus.Milvus{Addr: *milvusAddr},
&graphrag.GraphKnowledge{Addr: *graphragAddr},
&googlegenai.GoogleAI{APIKey: *googleAIApiKey},
&evaluators.GenkitEval{Metrics: metrics},
))
if err != nil {
@@ -277,427 +277,3 @@ func concatMessageParts(parts []*ai.Part) string {
Msg("Concatenation complete")
return result
}
// package deepseek
// import (
// "context"
// "fmt"
// "strings"
// "sync"
// "github.com/firebase/genkit/go/ai"
// "github.com/firebase/genkit/go/genkit"
// "github.com/rs/zerolog/log"
// deepseek "github.com/cohesion-org/deepseek-go"
// )
// const provider = "deepseek"
// var (
// mediaSupportedModels = []string{deepseek.DeepSeekChat, deepseek.DeepSeekCoder, deepseek.DeepSeekReasoner}
// roleMapping = map[ai.Role]string{
// ai.RoleUser: deepseek.ChatMessageRoleUser,
// ai.RoleModel: deepseek.ChatMessageRoleAssistant,
// ai.RoleSystem: deepseek.ChatMessageRoleSystem,
// ai.RoleTool: deepseek.ChatMessageRoleTool,
// }
// )
// // DeepSeek holds configuration for the plugin.
// type DeepSeek struct {
// APIKey string // DeepSeek API key
// mu sync.Mutex // Mutex to control access.
// initted bool // Whether the plugin has been initialized.
// }
// // Name returns the provider name.
// func (d DeepSeek) Name() string {
// return provider
// }
// // ModelDefinition represents a model with its name and type.
// type ModelDefinition struct {
// Name string
// Type string
// }
// // DefineModel defines a DeepSeek model in Genkit.
// func (d *DeepSeek) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// log.Info().
// Str("method", "DeepSeek.DefineModel").
// Str("model_name", model.Name).
// Msg("Defining DeepSeek model")
// d.mu.Lock()
// defer d.mu.Unlock()
// if !d.initted {
// log.Error().Str("method", "DeepSeek.DefineModel").Msg("DeepSeek not initialized")
// panic("deepseek.Init not called")
// }
// // Define model info, supporting multiturn and system role.
// mi := ai.ModelInfo{
// Label: model.Name,
// Supports: &ai.ModelSupports{
// Multiturn: true,
// SystemRole: true,
// Media: false, // DeepSeek API primarily supports text.
// Tools: false, // Tools not yet supported in this implementation.
// },
// Versions: []string{},
// }
// if info != nil {
// mi = *info
// }
// meta := &ai.ModelInfo{
// Label: model.Name,
// Supports: mi.Supports,
// Versions: []string{},
// }
// gen := &generator{model: model, apiKey: d.APIKey}
// modelDef := genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
// log.Info().
// Str("method", "DeepSeek.DefineModel").
// Str("model_name", model.Name).
// Msg("Model defined successfully")
// return modelDef
// }
// // Init initializes the DeepSeek plugin.
// func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) error {
// log.Info().Str("method", "DeepSeek.Init").Msg("Initializing DeepSeek plugin")
// d.mu.Lock()
// defer d.mu.Unlock()
// if d.initted {
// log.Error().Str("method", "DeepSeek.Init").Msg("Plugin already initialized")
// return fmt.Errorf("deepseek.Init already called")
// }
// if d == nil || d.APIKey == "" {
// log.Error().Str("method", "DeepSeek.Init").Msg("APIKey is required")
// return fmt.Errorf("deepseek: need APIKey")
// }
// d.initted = true
// log.Info().Str("method", "DeepSeek.Init").Msg("Initialization successful")
// return nil
// }
// // generator handles model generation.
// type generator struct {
// model ModelDefinition
// apiKey string
// }
// // generate implements the Genkit model generation interface.
// func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// log.Info().
// Str("method", "generator.generate").
// Str("model_name", g.model.Name).
// Int("messages", len(input.Messages)).
// Msg("Starting model generation")
// if len(input.Messages) == 0 {
// log.Error().Str("method", "generator.generate").Msg("Prompt or messages required")
// return nil, fmt.Errorf("prompt or messages required")
// }
// // Initialize DeepSeek client.
// client := deepseek.NewClient(g.apiKey)
// log.Debug().Str("method", "generator.generate").Msg("DeepSeek client initialized")
// // Create a chat completion request
// request := &deepseek.ChatCompletionRequest{
// Model: g.model.Name,
// }
// for _, msg := range input.Messages {
// role, ok := roleMapping[msg.Role]
// if !ok {
// log.Error().
// Str("method", "generator.generate").
// Str("role", string(msg.Role)).
// Msg("Unsupported role")
// return nil, fmt.Errorf("unsupported role: %s", msg.Role)
// }
// content := concatMessageParts(msg.Content)
// request.Messages = append(request.Messages, deepseek.ChatCompletionMessage{
// Role: role,
// Content: content,
// })
// log.Debug().
// Str("method", "generator.generate").
// Str("role", role).
// Str("content", content).
// Msg("Added message to request")
// }
// // Send the request and handle the response
// response, err := client.CreateChatCompletion(ctx, request)
// if err != nil {
// log.Error().
// Err(err).
// Str("method", "generator.generate").
// Msg("Failed to create chat completion")
// return nil, fmt.Errorf("create chat completion: %w", err)
// }
// log.Debug().
// Str("method", "generator.generate").
// Int("choices", len(response.Choices)).
// Msg("Received chat completion response")
// // Create a final response with the merged chunks
// finalResponse := &ai.ModelResponse{
// Request: input,
// FinishReason: ai.FinishReason("stop"),
// Message: &ai.Message{
// Role: ai.RoleModel,
// },
// }
// for _, chunk := range response.Choices {
// log.Debug().
// Str("method", "generator.generate").
// Int("index", chunk.Index).
// Str("content", chunk.Message.Content).
// Msg("Processing response chunk")
// p := ai.Part{
// Text: chunk.Message.Content,
// Kind: ai.PartKind(chunk.Index),
// }
// finalResponse.Message.Content = append(finalResponse.Message.Content, &p)
// }
// log.Info().
// Str("method", "generator.generate").
// Str("model_name", g.model.Name).
// Int("content_parts", len(finalResponse.Message.Content)).
// Msg("Model generation completed successfully")
// return finalResponse, nil
// }
// // concatMessageParts concatenates message parts into a single string.
// func concatMessageParts(parts []*ai.Part) string {
// log.Debug().
// Str("method", "concatMessageParts").
// Int("parts", len(parts)).
// Msg("Concatenating message parts")
// var sb strings.Builder
// for _, part := range parts {
// if part.IsText() {
// sb.WriteString(part.Text)
// }
// // Ignore non-text parts (e.g., media, tools) as DeepSeek API doesn't support them.
// }
// result := sb.String()
// log.Debug().
// Str("method", "concatMessageParts").
// Str("result", result).
// Msg("Concatenation complete")
// return result
// }
/*
// Choice represents a completion choice generated by the model.
type Choice struct {
Index int `json:"index"` // Index of the choice in the list of choices.
Message Message `json:"message"` // The message generated by the model.
Logprobs any `json:"logprobs,omitempty"` // Log probabilities of the tokens, if available. Changed to any as of April 21, 2025, because the logprobs field is sometimes a float64 and sometimes a Logprobs struct.
FinishReason string `json:"finish_reason"` // Reason why the completion finished.
}
// A Part is one part of a [Document]. This may be plain text or it
// may be a URL (possibly a "data:" URL with embedded data).
type Part struct {
Kind PartKind `json:"kind,omitempty"`
ContentType string `json:"contentType,omitempty"` // valid for kind==blob
Text string `json:"text,omitempty"` // valid for kind∈{text,blob}
ToolRequest *ToolRequest `json:"toolRequest,omitempty"` // valid for kind==partToolRequest
ToolResponse *ToolResponse `json:"toolResponse,omitempty"` // valid for kind==partToolResponse
Custom map[string]any `json:"custom,omitempty"` // valid for plugin-specific custom parts
Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds
}
*/
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
type EmbedOptions struct {
Model string `json:"model"`
}
type ollamaEmbedRequest struct {
Model string `json:"model"`
Input any `json:"input"` // TODO: using any to handle both string and []string; find a better solution
Options map[string]any `json:"options,omitempty"`
}
type ollamaEmbedResponse struct {
Embeddings [][]float32 `json:"embeddings"`
}
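// For reference, a sketch of the /api/embed wire format this code assumes
// (field values illustrative): the request carries either a single string or a
// list of strings as "input", and the response returns one vector per input:
//	{"model": "all-minilm", "input": ["hello"]}  ->  {"embeddings": [[0.1, 0.2, ...]]}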
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
options, ok := req.Options.(*EmbedOptions)
if !ok && req.Options != nil {
return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
}
if options == nil || options.Model == "" {
return nil, fmt.Errorf("invalid embedding model: model must be specified")
}
if serverAddress == "" {
return nil, fmt.Errorf("invalid server address: address cannot be empty")
}
ollamaReq := newOllamaEmbedRequest(options.Model, req.Input)
jsonData, err := json.Marshal(ollamaReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal embed request: %w", err)
}
resp, err := sendEmbedRequest(ctx, serverAddress, jsonData)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode)
}
var ollamaResp ollamaEmbedResponse
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
return nil, fmt.Errorf("failed to decode embed response: %w", err)
}
return newEmbedResponse(ollamaResp.Embeddings), nil
}
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
client := &http.Client{}
httpReq, err := http.NewRequestWithContext(ctx, "POST", serverAddress+"/api/embed", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
return client.Do(httpReq)
}
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest {
var input any
if len(documents) == 1 {
input = concatenateText(documents[0])
} else {
texts := make([]string, len(documents))
for i, doc := range documents {
texts[i] = concatenateText(doc)
}
input = texts
}
return ollamaEmbedRequest{
Model: model,
Input: input,
}
}
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse {
resp := &ai.EmbedResponse{
Embeddings: make([]*ai.Embedding, len(embeddings)),
}
for i, embedding := range embeddings {
resp.Embeddings[i] = &ai.Embedding{Embedding: embedding}
}
return resp
}
func concatenateText(doc *ai.Document) string {
var builder strings.Builder
for _, part := range doc.Content {
builder.WriteString(part.Text)
}
result := builder.String()
return result
}
// DefineEmbedder defines an embedder with a given server address.
func (o *Ollama) DefineEmbedder(g *genkit.Genkit, serverAddress string, model string) ai.Embedder {
o.mu.Lock()
defer o.mu.Unlock()
if !o.initted {
panic("ollama.Init not called")
}
return genkit.DefineEmbedder(g, provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
if req.Options == nil {
req.Options = &EmbedOptions{Model: model}
}
if req.Options.(*EmbedOptions).Model == "" {
req.Options.(*EmbedOptions).Model = model
}
return embed(ctx, serverAddress, req)
})
}
// IsDefinedEmbedder reports whether the embedder with the given server address is defined by this plugin.
func IsDefinedEmbedder(g *genkit.Genkit, serverAddress string) bool {
return genkit.LookupEmbedder(g, provider, serverAddress) != nil
}
// Embedder returns the [ai.Embedder] with the given server address.
// It returns nil if the embedder was not defined.
func Embedder(g *genkit.Genkit, serverAddress string) ai.Embedder {
return genkit.LookupEmbedder(g, provider, serverAddress)
}
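// A minimal usage sketch (model name illustrative; assumes a local Ollama
// server with the model pulled and the plugin already initialized):
//
//	e := o.DefineEmbedder(g, "http://localhost:11434", "all-minilm")
//	resp, err := e.Embed(ctx, &ai.EmbedRequest{
//		Input: []*ai.Document{ai.DocumentFromText("hello world", nil)},
//	})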
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/firebase/genkit/go/ai"
)
func TestEmbedValidRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(ollamaEmbedResponse{
Embeddings: [][]float32{{0.1, 0.2, 0.3}},
})
}))
defer server.Close()
req := &ai.EmbedRequest{
Input: []*ai.Document{
ai.DocumentFromText("test", nil),
},
Options: &EmbedOptions{Model: "all-minilm"},
}
resp, err := embed(context.Background(), server.URL, req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(resp.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings))
}
}
func TestEmbedInvalidServerAddress(t *testing.T) {
req := &ai.EmbedRequest{
Input: []*ai.Document{
ai.DocumentFromText("test", nil),
},
Options: &EmbedOptions{Model: "all-minilm"},
}
_, err := embed(context.Background(), "", req)
if err == nil || !strings.Contains(err.Error(), "invalid server address") {
t.Fatalf("expected invalid server address error, got %v", err)
}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"slices"
"strings"
"sync"
"time"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/internal/uri"
"github.com/rs/zerolog/log"
)
const provider = "ollama"
// Provider is the exported provider name for this plugin.
const Provider = "ollama"
var (
mediaSupportedModels = []string{"llava", "bakllava", "llava-llama3", "llava:13b", "llava:7b", "llava:latest"}
toolSupportedModels = []string{
"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
"granite3.3", "command-a", "command-r7b-arabic",
}
roleMapping = map[ai.Role]string{
ai.RoleUser: "user",
ai.RoleModel: "assistant",
ai.RoleSystem: "system",
ai.RoleTool: "tool",
}
)
// ListModels returns a map of media-supported models and their capabilities
func ListModels() (map[string]ai.ModelInfo, error) {
models := make(map[string]ai.ModelInfo, len(mediaSupportedModels))
for _, modelName := range mediaSupportedModels {
// Normalize model name by removing version tags (e.g., "llava:13b" -> "llava")
baseName := strings.Split(modelName, ":")[0]
models[modelName] = ai.ModelInfo{
Label: "Ollama - " + baseName,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: true, // All models in mediaSupportedModels support media
Tools: false, // None of these models are in toolSupportedModels
},
Versions: []string{},
}
}
return models, nil
}
func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
// Locking is intentionally skipped here: Init calls DefineModel while already
// holding o.mu, so taking the lock again would deadlock.
// o.mu.Lock()
// defer o.mu.Unlock()
if !o.initted {
panic("ollama.Init not called")
}
var mi ai.ModelInfo
if info != nil {
mi = *info
} else {
// Check if the model supports tools (must be a chat model and in the supported list)
supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
mi = ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: slices.Contains(mediaSupportedModels, model.Name),
Tools: supportsTools,
},
Versions: []string{},
}
}
meta := &ai.ModelInfo{
Label: "Ollama - " + model.Name,
Supports: mi.Supports,
Versions: []string{},
}
gen := &generator{model: model, serverAddress: o.ServerAddress}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
}
// IsDefinedModel reports whether a model is defined.
func IsDefinedModel(g *genkit.Genkit, name string) bool {
return genkit.LookupModel(g, provider, name) != nil
}
// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(g *genkit.Genkit, name string) ai.Model {
return genkit.LookupModel(g, provider, name)
}
// ModelDefinition represents a model with its name and type.
type ModelDefinition struct {
Name string
Type string
}
type generator struct {
model ModelDefinition
serverAddress string
}
type ollamaMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
}
// Ollama has two API endpoints: one with a chat interface and another with a generate-response interface.
// That's why we have multiple request types for the Ollama API below.
/*
TODO: Support optional, advanced parameters:
format: the format to return a response in. Currently the only accepted value is json
options: additional model parameters listed in the documentation for the Modelfile such as temperature
system: system message to (overrides what is defined in the Modelfile)
template: the prompt template to use (overrides what is defined in the Modelfile)
context: the context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory
stream: if false the response will be returned as a single response object, rather than a stream of objects
raw: if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
*/
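// For reference, the two endpoints accept payloads shaped roughly as follows
// (model name and content illustrative):
//	POST /api/chat     {"model": "llama3.2", "messages": [{"role": "user", "content": "hi"}], "stream": false}
//	POST /api/generate {"model": "llama3.2", "prompt": "hi", "stream": false}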
type ollamaChatRequest struct {
Messages []*ollamaMessage `json:"messages"`
Images []string `json:"images,omitempty"`
Model string `json:"model"`
Stream bool `json:"stream"`
Format string `json:"format,omitempty"`
Tools []ollamaTool `json:"tools,omitempty"`
}
type ollamaModelRequest struct {
System string `json:"system,omitempty"`
Images []string `json:"images,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Format string `json:"format,omitempty"`
}
// Tool definition from Ollama API
type ollamaTool struct {
Type string `json:"type"`
Function ollamaFunction `json:"function"`
}
// Function definition for Ollama API
type ollamaFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]any `json:"parameters"`
}
// Tool Call from Ollama API
type ollamaToolCall struct {
Function ollamaFunctionCall `json:"function"`
}
// Function Call for Ollama API
type ollamaFunctionCall struct {
Name string `json:"name"`
Arguments any `json:"arguments"`
}
// TODO: Add optional parameters (images, format, options, etc.) based on your use case
type ollamaChatResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
} `json:"message"`
}
type ollamaModelResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
}
// Ollama provides configuration options for the Init function.
type Ollama struct {
ServerAddress string // Server address of Ollama.
mu sync.Mutex // Mutex to control access.
initted bool // Whether the plugin has been initialized.
}
func (o *Ollama) Name() string {
return provider
}
// Init initializes the plugin.
// Unlike the upstream Ollama plugin, this version registers every model in
// mediaSupportedModels and toolSupportedModels during initialization, so
// [DefineModel] only needs to be called for models outside those lists.
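//
// A minimal usage sketch (model name illustrative):
//
//	o := &Ollama{ServerAddress: "http://localhost:11434"}
//	g, err := genkit.Init(ctx, genkit.WithPlugins(o))
//	...
//	m := Model(g, "llama3.2") // registered during Init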
func (o *Ollama) Init(ctx context.Context, g *genkit.Genkit) (err error) {
o.mu.Lock()
defer o.mu.Unlock()
if o.initted {
panic("ollama.Init already called")
}
if o.ServerAddress == "" {
return errors.New("ollama: need ServerAddress")
}
o.initted = true
// Register all supported models
modelSet := make(map[string]struct{})
for _, modelName := range mediaSupportedModels {
modelSet[modelName] = struct{}{}
}
for _, modelName := range toolSupportedModels {
modelSet[modelName] = struct{}{}
}
for modelName := range modelSet {
// Determine model type
modelType := "chat"
if slices.Contains(mediaSupportedModels, modelName) && !slices.Contains(toolSupportedModels, modelName) {
modelType = "vision"
}
modelDef := ModelDefinition{
Name: modelName,
Type: modelType,
}
o.DefineModel(g, modelDef, nil)
log.Info().
Str("method", "Ollama.Init").
Str("model_name", modelName).
Str("model_type", modelType).
Msg("Registered model")
}
log.Info().Str("method", "Ollama.Init").Msg("Initialization successful")
return nil
}
// generate makes a request to the Ollama API and processes the response.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
stream := cb != nil
var payload any
isChatModel := g.model.Type == "chat"
// Check if this is an image model
hasMediaSupport := slices.Contains(mediaSupportedModels, g.model.Name)
// Extract images if the model supports them
var images []string
var err error
if hasMediaSupport {
images, err = concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
if err != nil {
return nil, fmt.Errorf("failed to grab image parts: %v", err)
}
}
if !isChatModel {
payload = ollamaModelRequest{
Model: g.model.Name,
Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
System: concatMessages(input, []ai.Role{ai.RoleSystem}),
Images: images,
Stream: stream,
}
} else {
var messages []*ollamaMessage
// Translate all messages to ollama message format.
for _, m := range input.Messages {
message, err := convertParts(m.Role, m.Content)
if err != nil {
return nil, fmt.Errorf("failed to convert message parts: %v", err)
}
messages = append(messages, message)
}
chatReq := ollamaChatRequest{
Messages: messages,
Model: g.model.Name,
Stream: stream,
Images: images,
}
if len(input.Tools) > 0 {
tools, err := convertTools(input.Tools)
if err != nil {
return nil, fmt.Errorf("failed to convert tools: %v", err)
}
chatReq.Tools = tools
}
payload = chatReq
}
client := &http.Client{Timeout: 30 * time.Second}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
// Determine the correct endpoint
endpoint := g.serverAddress + "/api/chat"
if !isChatModel {
endpoint = g.serverAddress + "/api/generate"
}
req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()
if cb == nil {
// Existing behavior for non-streaming responses
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
}
var response *ai.ModelResponse
if isChatModel {
response, err = translateChatResponse(body)
} else {
response, err = translateModelResponse(body)
}
if err != nil {
return nil, fmt.Errorf("failed to parse response: %v", err)
}
response.Request = input
return response, nil
} else {
var chunks []*ai.ModelResponseChunk
scanner := bufio.NewScanner(resp.Body)
chunkCount := 0
for scanner.Scan() {
line := scanner.Text()
chunkCount++
var chunk *ai.ModelResponseChunk
if isChatModel {
chunk, err = translateChatChunk(line)
} else {
chunk, err = translateGenerateChunk(line)
}
if err != nil {
return nil, fmt.Errorf("failed to translate chunk: %v", err)
}
chunks = append(chunks, chunk)
if err := cb(ctx, chunk); err != nil {
return nil, fmt.Errorf("streaming callback failed: %v", err)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading response stream: %v", err)
}
// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
// Add all the merged content to the final response's candidate
for _, chunk := range chunks {
finalResponse.Message.Content = append(finalResponse.Message.Content, chunk.Content...)
}
return finalResponse, nil // Return the final merged response
}
}
// convertTools converts Genkit tool definitions to Ollama tool format
func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
ollamaTools := make([]ollamaTool, 0, len(tools))
for _, tool := range tools {
ollamaTools = append(ollamaTools, ollamaTool{
Type: "function",
Function: ollamaFunction{
Name: tool.Name,
Description: tool.Description,
Parameters: tool.InputSchema,
},
})
}
return ollamaTools, nil
}
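// A converted tool marshals to the Ollama wire format, roughly (names and
// schema illustrative):
//	{"type": "function", "function": {"name": "get_weather", "description": "...", "parameters": {"type": "object"}}}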
func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
message := &ollamaMessage{
Role: roleMapping[role],
}
var contentBuilder strings.Builder
var toolCalls []ollamaToolCall
var images []string
for _, part := range parts {
if part.IsText() {
contentBuilder.WriteString(part.Text)
} else if part.IsMedia() {
_, data, err := uri.Data(part)
if err != nil {
return nil, fmt.Errorf("failed to extract media data: %v", err)
}
base64Encoded := base64.StdEncoding.EncodeToString(data)
images = append(images, base64Encoded)
} else if part.IsToolRequest() {
toolReq := part.ToolRequest
toolCalls = append(toolCalls, ollamaToolCall{
Function: ollamaFunctionCall{
Name: toolReq.Name,
Arguments: toolReq.Input,
},
})
} else if part.IsToolResponse() {
toolResp := part.ToolResponse
outputJSON, err := json.Marshal(toolResp.Output)
if err != nil {
return nil, fmt.Errorf("failed to marshal tool response: %v", err)
}
contentBuilder.WriteString(string(outputJSON))
} else {
return nil, errors.New("unsupported content type")
}
}
message.Content = contentBuilder.String()
if len(toolCalls) > 0 {
message.ToolCalls = toolCalls
}
if len(images) > 0 {
message.Images = images
}
return message, nil
}
// translateChatResponse translates Ollama chat response into a genkit response.
func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
var response ollamaChatResponse
if err := json.Unmarshal(responseData, &response); err != nil {
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
}
modelResponse := &ai.ModelResponse{
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
if len(response.Message.ToolCalls) > 0 {
for _, toolCall := range response.Message.ToolCalls {
toolRequest := &ai.ToolRequest{
Name: toolCall.Function.Name,
Input: toolCall.Function.Arguments,
}
toolPart := ai.NewToolRequestPart(toolRequest)
modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
}
} else if response.Message.Content != "" {
aiPart := ai.NewTextPart(response.Message.Content)
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
}
return modelResponse, nil
}
// translateModelResponse translates Ollama generate response into a genkit response.
func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
var response ollamaModelResponse
if err := json.Unmarshal(responseData, &response); err != nil {
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
}
modelResponse := &ai.ModelResponse{
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}
aiPart := ai.NewTextPart(response.Response)
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
modelResponse.Usage = &ai.GenerationUsage{} // TODO: can we get any of this info?
return modelResponse, nil
}
func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
var response ollamaChatResponse
if err := json.Unmarshal([]byte(input), &response); err != nil {
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
}
chunk := &ai.ModelResponseChunk{}
if len(response.Message.ToolCalls) > 0 {
for _, toolCall := range response.Message.ToolCalls {
toolRequest := &ai.ToolRequest{
Name: toolCall.Function.Name,
Input: toolCall.Function.Arguments,
}
toolPart := ai.NewToolRequestPart(toolRequest)
chunk.Content = append(chunk.Content, toolPart)
}
} else if response.Message.Content != "" {
aiPart := ai.NewTextPart(response.Message.Content)
chunk.Content = append(chunk.Content, aiPart)
}
return chunk, nil
}
func translateGenerateChunk(input string) (*ai.ModelResponseChunk, error) {
var response ollamaModelResponse
if err := json.Unmarshal([]byte(input), &response); err != nil {
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
}
chunk := &ai.ModelResponseChunk{}
aiPart := ai.NewTextPart(response.Response)
chunk.Content = append(chunk.Content, aiPart)
return chunk, nil
}
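// For reference, each streamed line from /api/generate is a standalone JSON
// object, roughly (values illustrative):
//	{"model": "tinyllama", "created_at": "...", "response": "Hel", "done": false}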
// concatMessages translates a list of messages into a prompt-style format
func concatMessages(input *ai.ModelRequest, roles []ai.Role) string {
roleSet := make(map[ai.Role]bool)
for _, role := range roles {
roleSet[role] = true // Create a set for faster lookup
}
var sb strings.Builder
for _, message := range input.Messages {
// Check if the message role is in the allowed set
if !roleSet[message.Role] {
continue
}
for _, part := range message.Content {
if !part.IsText() {
continue
}
sb.WriteString(part.Text)
}
}
return sb.String()
}
// concatImages grabs the images from genkit message parts
func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error) {
roleSet := make(map[ai.Role]bool)
for _, role := range roleFilter {
roleSet[role] = true
}
var images []string
for _, message := range input.Messages {
// Check if the message role is in the allowed set
if roleSet[message.Role] {
for _, part := range message.Content {
if !part.IsMedia() {
continue
}
// Get the media type and data
mediaType, data, err := uri.Data(part)
if err != nil {
return nil, fmt.Errorf("failed to extract image data: %v", err)
}
// Only include image media types
if !strings.HasPrefix(mediaType, "image/") {
continue
}
base64Encoded := base64.StdEncoding.EncodeToString(data)
images = append(images, base64Encoded)
}
}
}
return images, nil
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama_test
import (
"context"
"flag"
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
ollamaPlugin "github.com/firebase/genkit/go/plugins/ollama"
)
var serverAddress = flag.String("server-address", "http://localhost:11434", "Ollama server address")
var modelName = flag.String("model-name", "tinyllama", "model name")
var testLive = flag.Bool("test-live", false, "run live tests")
/*
To run this test, you need to have the Ollama server running. You can set the server address using the -server-address flag.
If the flag is not set, the test defaults to http://localhost:11434 (the default address for the Ollama server).
*/
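// For example (assuming a local server that has already pulled the model):
//	ollama pull tinyllama
//	go test -test-live -model-name tinyllama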
func TestLive(t *testing.T) {
if !*testLive {
t.Skip("skipping go/plugins/ollama live test")
}
ctx := context.Background()
g, err := genkit.Init(ctx)
if err != nil {
t.Fatal(err)
}
o := &ollamaPlugin.Ollama{ServerAddress: *serverAddress}
// Initialize the Ollama plugin
if err = o.Init(ctx, g); err != nil {
t.Fatalf("failed to initialize Ollama plugin: %s", err)
}
// Define the model
o.DefineModel(g, ollamaPlugin.ModelDefinition{Name: *modelName}, nil)
// Use the Ollama model
m := ollamaPlugin.Model(g, *modelName)
if m == nil {
t.Fatalf(`failed to find model: %s`, *modelName)
}
// Generate a response from the model
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 1}),
ai.WithPrompt("I'm hungry, what should I eat?"),
)
if err != nil {
t.Fatalf("failed to generate response: %s", err)
}
if resp == nil {
t.Fatalf("response is nil")
}
// Get the text from the response
text := resp.Text()
// log.Println("Response:", text)
// Assert that the response text is as expected
if text == "" {
t.Fatalf("expected non-empty response, got: %s", text)
}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
var _ genkit.Plugin = (*Ollama)(nil)
func TestConcatMessages(t *testing.T) {
tests := []struct {
name string
messages []*ai.Message
roles []ai.Role
want string
}{
{
name: "Single message with matching role",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Hello, how are you?")},
},
},
roles: []ai.Role{ai.RoleUser},
want: "Hello, how are you?",
},
{
name: "Multiple messages with mixed roles",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Tell me a joke.")},
},
{
Role: ai.RoleModel,
Content: []*ai.Part{ai.NewTextPart("Why did the scarecrow win an award? Because he was outstanding in his field!")},
},
},
roles: []ai.Role{ai.RoleModel},
want: "Why did the scarecrow win an award? Because he was outstanding in his field!",
},
{
name: "No matching role",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Any suggestions?")},
},
},
roles: []ai.Role{ai.RoleSystem},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input := &ai.ModelRequest{Messages: tt.messages}
got := concatMessages(input, tt.roles)
if got != tt.want {
t.Errorf("concatMessages() = %q, want %q", got, tt.want)
}
})
}
}
func TestTranslateGenerateChunk(t *testing.T) {
tests := []struct {
name string
input string
want *ai.ModelResponseChunk
wantErr bool
}{
{
name: "Valid JSON response",
input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`,
want: &ai.ModelResponseChunk{
Content: []*ai.Part{ai.NewTextPart("This is a test response.")},
},
wantErr: false,
},
{
name: "Invalid JSON",
input: `{invalid}`,
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := translateGenerateChunk(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !equalContent(got.Content, tt.want.Content) {
t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want)
}
})
}
}
// Helper function to compare content
func equalContent(a, b []*ai.Part) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Text != b[i].Text || !a[i].IsText() || !b[i].IsText() {
return false
}
}
return true
}