Commit 93bfed3e authored by Wade's avatar Wade

add genkit

parent e76d4bb2
Pipeline #864 failed with stages
# Genkit
This package is the Go version of Genkit, a framework for building
AI-powered apps. See: https://genkit.dev/go/docs/get-started-go
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"context"
"os"
"testing"
"github.com/firebase/genkit/go/internal/registry"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
type specSuite struct {
Tests []testCase `yaml:"tests"`
}
type testCase struct {
Name string `yaml:"name"`
Input *GenerateActionOptions `yaml:"input"`
StreamChunks [][]*ModelResponseChunk `yaml:"streamChunks,omitempty"`
ModelResponses []*ModelResponse `yaml:"modelResponses"`
ExpectResponse *ModelResponse `yaml:"expectResponse,omitempty"`
Stream bool `yaml:"stream,omitempty"`
ExpectChunks []*ModelResponseChunk `yaml:"expectChunks,omitempty"`
}
type programmableModel struct {
r *registry.Registry
handleResp func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error)
lastRequest *ModelRequest
}
func (pm *programmableModel) Name() string {
return "programmableModel"
}
func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) {
// Make a copy of the request to modify for testing purposes
if req != nil && req.Tools != nil {
for _, tool := range req.Tools {
if tool.Name == "testTool" {
// Set the schema fields directly
tool.InputSchema = map[string]any{"$schema": "http://json-schema.org/draft-07/schema#"}
tool.OutputSchema = map[string]any{"$schema": "http://json-schema.org/draft-07/schema#"}
}
}
}
pm.lastRequest = req
return pm.handleResp(ctx, req, cb)
}
func defineProgrammableModel(r *registry.Registry) *programmableModel {
pm := &programmableModel{r: r}
supports := &ModelSupports{
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
}
func TestGenerateAction(t *testing.T) {
data, err := os.ReadFile("../../tests/specs/generate.yaml")
if err != nil {
t.Fatalf("failed to read spec file: %v", err)
}
var suite specSuite
if err := yaml.Unmarshal(data, &suite); err != nil {
t.Fatalf("failed to parse spec file: %v", err)
}
for _, tc := range suite.Tests {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()
r, err := registry.New()
if err != nil {
t.Fatalf("failed to create registry: %v", err)
}
ConfigureFormats(r)
pm := defineProgrammableModel(r)
DefineTool(r, "testTool", "description",
func(ctx *ToolContext, input any) (any, error) {
return "tool called", nil
})
if len(tc.ModelResponses) > 0 || len(tc.StreamChunks) > 0 {
reqCounter := 0
pm.handleResp = func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) {
if len(tc.StreamChunks) > 0 && cb != nil {
for _, chunk := range tc.StreamChunks[reqCounter] {
if err := cb(ctx, chunk); err != nil {
return nil, err
}
}
}
resp := tc.ModelResponses[reqCounter]
resp.Request = req
resp.Custom = map[string]any{}
resp.Request.Output = &ModelOutputConfig{}
resp.Usage = &GenerationUsage{}
reqCounter++
return resp, nil
}
}
genAction := DefineGenerateAction(ctx, r)
if tc.Stream {
chunks := []*ModelResponseChunk{}
streamCb := func(ctx context.Context, chunk *ModelResponseChunk) error {
chunks = append(chunks, chunk)
return nil
}
resp, err := genAction.Run(ctx, tc.Input, streamCb)
if err != nil {
t.Fatalf("action failed: %v", err)
}
if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" {
t.Errorf("chunks mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
} else {
resp, err := genAction.Run(ctx, tc.Input, nil)
if err != nil {
t.Fatalf("action failed: %v", err)
}
if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
}
})
}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"encoding/json"
"fmt"
)
// A Document is a piece of data that can be embedded, indexed, or retrieved.
// It includes metadata. It can contain multiple parts.
type Document struct {
// The data that is part of this document.
Content []*Part `json:"content,omitempty"`
// The metadata for this document.
Metadata map[string]any `json:"metadata,omitempty"`
}
// A Part is one part of a [Document]. This may be plain text or it
// may be a URL (possibly a "data:" URL with embedded data).
type Part struct {
Kind PartKind `json:"kind,omitempty"`
ContentType string `json:"contentType,omitempty"` // valid for kind==blob
Text string `json:"text,omitempty"` // valid for kind∈{text,blob}
ToolRequest *ToolRequest `json:"toolRequest,omitempty"` // valid for kind==partToolRequest
ToolResponse *ToolResponse `json:"toolResponse,omitempty"` // valid for kind==partToolResponse
Custom map[string]any `json:"custom,omitempty"` // valid for plugin-specific custom parts
Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds
}
type PartKind int8
const (
PartText PartKind = iota
PartMedia
PartData
PartToolRequest
PartToolResponse
PartCustom
PartReasoning
)
// NewTextPart returns a Part containing text.
func NewTextPart(text string) *Part {
return &Part{Kind: PartText, ContentType: "plain/text", Text: text}
}
// NewJSONPart returns a Part containing JSON.
func NewJSONPart(text string) *Part {
return &Part{Kind: PartText, ContentType: "application/json", Text: text}
}
// NewMediaPart returns a Part containing structured data described
// by the given mimeType.
func NewMediaPart(mimeType, contents string) *Part {
return &Part{Kind: PartMedia, ContentType: mimeType, Text: contents}
}
// NewDataPart returns a Part containing raw string data.
func NewDataPart(contents string) *Part {
return &Part{Kind: PartData, Text: contents}
}
// NewToolRequestPart returns a Part containing a request from
// the model to the client to run a Tool.
// (Only genkit plugins should need to use this function.)
func NewToolRequestPart(r *ToolRequest) *Part {
return &Part{Kind: PartToolRequest, ToolRequest: r}
}
// NewToolResponsePart returns a Part containing the results
// of applying a Tool that the model requested.
func NewToolResponsePart(r *ToolResponse) *Part {
return &Part{Kind: PartToolResponse, ToolResponse: r}
}
// NewCustomPart returns a Part containing custom plugin-specific data.
func NewCustomPart(customData map[string]any) *Part {
return &Part{Kind: PartCustom, Custom: customData}
}
// NewReasoningPart returns a Part containing reasoning text
func NewReasoningPart(text string) *Part {
return &Part{Kind: PartReasoning, ContentType: "plain/text", Text: text}
}
// IsText reports whether the [Part] contains plain text.
func (p *Part) IsText() bool {
return p.Kind == PartText
}
// IsMedia reports whether the [Part] contains structured media data.
func (p *Part) IsMedia() bool {
return p.Kind == PartMedia
}
// IsData reports whether the [Part] contains unstructured data.
func (p *Part) IsData() bool {
return p.Kind == PartData
}
// IsToolRequest reports whether the [Part] contains a request to run a tool.
func (p *Part) IsToolRequest() bool {
return p.Kind == PartToolRequest
}
// IsToolResponse reports whether the [Part] contains the result of running a tool.
func (p *Part) IsToolResponse() bool {
return p.Kind == PartToolResponse
}
// IsCustom reports whether the [Part] contains custom plugin-specific data.
func (p *Part) IsCustom() bool {
return p.Kind == PartCustom
}
// IsReasoning reports whether the [Part] contains a reasoning text
func (p *Part) IsReasoning() bool {
return p.Kind == PartReasoning
}
// MarshalJSON is called by the JSON marshaler to write out a Part.
func (p *Part) MarshalJSON() ([]byte, error) {
// This is not handled by the schema generator because
// Part is defined in TypeScript as a union.
switch p.Kind {
case PartText:
v := textPart{
Text: p.Text,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartMedia:
v := mediaPart{
Media: &Media{
ContentType: p.ContentType,
Url: p.Text,
},
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartData:
v := dataPart{
Data: p.Text,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartToolRequest:
v := toolRequestPart{
ToolRequest: p.ToolRequest,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartToolResponse:
v := toolResponsePart{
ToolResponse: p.ToolResponse,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartCustom:
v := customPart{
Custom: p.Custom,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartReasoning:
v := reasoningPart{
Reasoning: p.Text,
Metadata: p.Metadata,
}
return json.Marshal(v)
default:
return nil, fmt.Errorf("invalid part kind %v", p.Kind)
}
}
type partSchema struct {
Text string `json:"text,omitempty" yaml:"text,omitempty"`
Media *Media `json:"media,omitempty" yaml:"media,omitempty"`
Data string `json:"data,omitempty" yaml:"data,omitempty"`
ToolRequest *ToolRequest `json:"toolRequest,omitempty" yaml:"toolRequest,omitempty"`
ToolResponse *ToolResponse `json:"toolResponse,omitempty" yaml:"toolResponse,omitempty"`
Custom map[string]any `json:"custom,omitempty" yaml:"custom,omitempty"`
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty"`
Reasoning string `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
}
// unmarshalPartFromSchema updates Part p based on the schema s.
func (p *Part) unmarshalPartFromSchema(s partSchema) {
switch {
case s.Media != nil:
p.Kind = PartMedia
p.Text = s.Media.Url
p.ContentType = s.Media.ContentType
case s.ToolRequest != nil:
p.Kind = PartToolRequest
p.ToolRequest = s.ToolRequest
case s.ToolResponse != nil:
p.Kind = PartToolResponse
p.ToolResponse = s.ToolResponse
case s.Custom != nil:
p.Kind = PartCustom
p.Custom = s.Custom
default:
p.Kind = PartText
p.Text = s.Text
p.ContentType = ""
if s.Data != "" {
// Note: if part is completely empty, we use text by default.
p.Kind = PartData
p.Text = s.Data
}
}
p.Metadata = s.Metadata
}
// UnmarshalJSON is called by the JSON unmarshaler to read a Part.
func (p *Part) UnmarshalJSON(b []byte) error {
var s partSchema
if err := json.Unmarshal(b, &s); err != nil {
return err
}
p.unmarshalPartFromSchema(s)
return nil
}
// UnmarshalYAML implements goccy/go-yaml library's InterfaceUnmarshaler interface.
func (p *Part) UnmarshalYAML(unmarshal func(any) error) error {
var s partSchema
if err := unmarshal(&s); err != nil {
return err
}
p.unmarshalPartFromSchema(s)
return nil
}
// JSONSchemaAlias tells the JSON schema reflection code to use a different
// type for the schema for this type. This is needed because the JSON
// marshaling of Part uses a schema that matches the TypeScript code,
// rather than the natural JSON marshaling. This matters because the
// current JSON validation code works by marshaling the JSON.
func (Part) JSONSchemaAlias() any {
return partSchema{}
}
// DocumentFromText returns a [Document] containing a single plain text part.
// This takes ownership of the metadata map.
func DocumentFromText(text string, metadata map[string]any) *Document {
return &Document{
Content: []*Part{
{
Kind: PartText,
Text: text,
},
},
Metadata: metadata,
}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"encoding/json"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestDocumentFromText(t *testing.T) {
const data = "robot overlord"
d := DocumentFromText(data, nil)
if len(d.Content) != 1 {
t.Fatalf("got %d parts, want 1", len(d.Content))
}
p := d.Content[0]
if !p.IsText() {
t.Errorf("IsText() == %t, want %t", p.IsText(), true)
}
if got := p.Text; got != data {
t.Errorf("Data() == %q, want %q", got, data)
}
}
// TODO: verify that this works with the data that genkit passes.
func TestDocumentJSON(t *testing.T) {
d := Document{
Content: []*Part{
&Part{
Kind: PartText,
Text: "hi",
},
&Part{
Kind: PartMedia,
ContentType: "text/plain",
Text: "data:,bye",
},
&Part{
Kind: PartData,
Text: "somedata\x00string",
},
&Part{
Kind: PartToolRequest,
ToolRequest: &ToolRequest{
Name: "tool1",
Input: map[string]any{"arg1": 3.3, "arg2": "foo"},
},
},
&Part{
Kind: PartToolResponse,
ToolResponse: &ToolResponse{
Name: "tool1",
Output: map[string]any{"res1": 4.4, "res2": "bar"},
},
},
},
}
b, err := json.Marshal(&d)
if err != nil {
t.Fatal(err)
}
t.Logf("marshaled:%s\n", string(b))
var d2 Document
if err := json.Unmarshal(b, &d2); err != nil {
t.Fatal(err)
}
cmpPart := func(a, b *Part) bool {
if a.Kind != b.Kind {
return false
}
switch a.Kind {
case PartText:
return a.Text == b.Text
case PartMedia:
return a.ContentType == b.ContentType && a.Text == b.Text
case PartData:
return a.Text == b.Text
case PartToolRequest:
return reflect.DeepEqual(a.ToolRequest, b.ToolRequest)
case PartToolResponse:
return reflect.DeepEqual(a.ToolResponse, b.ToolResponse)
default:
t.Fatalf("bad part kind %v", a.Kind)
return false
}
}
diff := cmp.Diff(d, d2, cmp.Comparer(cmpPart))
if diff != "" {
t.Errorf("mismatch (-want, +got)\n%s", diff)
}
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"context"
"fmt"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
)
// Embedder represents an embedder that can perform content embedding.
type Embedder interface {
// Name returns the registry name of the embedder.
Name() string
// Embed embeds to content as part of the [EmbedRequest].
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}
// An embedder is used to convert a document to a multidimensional vector.
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
func DefineEmbedder(
r *registry.Registry,
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedder)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
}
// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, atype.Embedder, provider, name)
if action == nil {
return nil
}
return (*embedder)(action)
}
// Name returns the name of the embedder.
func (e *embedder) Name() string {
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Name()
}
// Embed runs the given [Embedder].
func (e *embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Run(ctx, req, nil)
}
// Embed invokes the embedder with provided options.
func Embed(ctx context.Context, e Embedder, opts ...EmbedderOption) (*EmbedResponse, error) {
embedOpts := &embedderOptions{}
for _, opt := range opts {
if err := opt.applyEmbedder(embedOpts); err != nil {
return nil, fmt.Errorf("ai.Embed: error applying options: %w", err)
}
}
req := &EmbedRequest{
Input: embedOpts.Documents,
Options: embedOpts.Config,
}
return e.Embed(ctx, req)
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"context"
"errors"
"fmt"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
)
// Evaluator represents a evaluator action.
type Evaluator interface {
// Name returns the name of the evaluator.
Name() string
// Evaluates a dataset.
Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error)
}
type evaluator core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
// Example is a single example that requires evaluation
type Example struct {
TestCaseId string `json:"testCaseId,omitempty"`
Input any `json:"input"`
Output any `json:"output,omitempty"`
Context []any `json:"context,omitempty"`
Reference any `json:"reference,omitempty"`
TraceIds []string `json:"traceIds,omitempty"`
}
// EvaluatorRequest is the data we pass to evaluate a dataset.
// The Options field is specific to the actual evaluator implementation.
type EvaluatorRequest struct {
Dataset []*Example `json:"dataset"`
EvaluationId string `json:"evalRunId"`
Options any `json:"options,omitempty"`
}
// ScoreStatus is an enum used to indicate if a Score has passed or failed. This
// drives additional features in tooling / the Dev UI.
type ScoreStatus int
const (
ScoreStatusUnknown ScoreStatus = iota
ScoreStatusFail
ScoreStatusPass
)
var statusName = map[ScoreStatus]string{
ScoreStatusUnknown: "UNKNOWN",
ScoreStatusFail: "FAIL",
ScoreStatusPass: "PASS",
}
func (ss ScoreStatus) String() string {
return statusName[ss]
}
// Score is the evaluation score that represents the result of an evaluator.
// This struct includes information such as the score (numeric, string or other
// types), the reasoning provided for this score (if any), the score status (if
// any) and other details.
type Score struct {
Id string `json:"id,omitempty"`
Score any `json:"score,omitempty"`
Status string `json:"status,omitempty" jsonschema:"enum=UNKNOWN,enum=FAIL,enum=PASS"`
Error string `json:"error,omitempty"`
Details map[string]any `json:"details,omitempty"`
}
// EvaluationResult is the result of running the evaluator on a single Example.
// An evaluator may provide multiple scores simultaneously (e.g. if they are using
// an API to score on multiple criteria)
type EvaluationResult struct {
TestCaseId string `json:"testCaseId"`
TraceID string `json:"traceId,omitempty"`
SpanID string `json:"spanId,omitempty"`
Evaluation []Score `json:"evaluation"`
}
// EvaluatorResponse is a collection of [EvaluationResult] structs, it
// represents the result on the entire input dataset.
type EvaluatorResponse = []EvaluationResult
type EvaluatorOptions struct {
DisplayName string `json:"displayName"`
Definition string `json:"definition"`
IsBilled bool `json:"isBilled,omitempty"`
}
// EvaluatorCallbackRequest is the data we pass to the callback function
// provided in defineEvaluator. The Options field is specific to the actual
// evaluator implementation.
type EvaluatorCallbackRequest struct {
Input Example `json:"input"`
Options any `json:"options,omitempty"`
}
// EvaluatorCallbackResponse is the result on evaluating a single [Example]
type EvaluatorCallbackResponse = EvaluationResult
// DefineEvaluator registers the given evaluator function as an action, and
// returns a [Evaluator] that runs it. This method process the input dataset
// one-by-one.
func DefineEvaluator(r *registry.Registry, provider, name string, options *EvaluatorOptions, eval func(context.Context, *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error)) (Evaluator, error) {
if options == nil {
return nil, errors.New("EvaluatorOptions must be provided")
}
// TODO(ssbushi): Set this on `evaluator` key on action metadata
metadataMap := map[string]any{}
metadataMap["evaluatorIsBilled"] = options.IsBilled
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition
actionDef := (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
var evalResponses []EvaluationResult
for _, datapoint := range req.Dataset {
if datapoint.TestCaseId == "" {
datapoint.TestCaseId = uuid.New().String()
}
_, err := tracing.RunInNewSpan(ctx, r.TracingState(), fmt.Sprintf("TestCase %s", datapoint.TestCaseId), "evaluator", false, datapoint,
func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) {
traceId := trace.SpanContextFromContext(ctx).TraceID().String()
spanId := trace.SpanContextFromContext(ctx).SpanID().String()
callbackRequest := EvaluatorCallbackRequest{
Input: *input,
Options: req.Options,
}
evaluatorResponse, err := eval(ctx, &callbackRequest)
if err != nil {
failedScore := Score{
Status: ScoreStatusFail.String(),
Error: fmt.Sprintf("Evaluation of test case %s failed: \n %s", input.TestCaseId, err.Error()),
}
failedEvalResult := EvaluationResult{
TestCaseId: input.TestCaseId,
Evaluation: []Score{failedScore},
TraceID: traceId,
SpanID: spanId,
}
evalResponses = append(evalResponses, failedEvalResult)
// return error to mark span as failed
return nil, err
}
evaluatorResponse.TraceID = traceId
evaluatorResponse.SpanID = spanId
evalResponses = append(evalResponses, *evaluatorResponse)
return evaluatorResponse, nil
})
if err != nil {
logger.FromContext(ctx).Debug("EvaluatorAction", "err", err)
continue
}
}
return &evalResponses, nil
}))
return actionDef, nil
}
// DefineBatchEvaluator registers the given evaluator function as an action, and
// returns a [Evaluator] that runs it. This method provide the full
// [EvaluatorRequest] to the callback function, giving more flexibilty to the
// user for processing the data, such as batching or parallelization.
func DefineBatchEvaluator(r *registry.Registry, provider, name string, options *EvaluatorOptions, batchEval func(context.Context, *EvaluatorRequest) (*EvaluatorResponse, error)) (Evaluator, error) {
if options == nil {
return nil, errors.New("EvaluatorOptions must be provided")
}
metadataMap := map[string]any{}
metadataMap["evaluatorIsBilled"] = options.IsBilled
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition
return (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
}
// LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator].
// It returns nil if the evaluator was not defined.
func LookupEvaluator(r *registry.Registry, provider, name string) Evaluator {
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
}
// Evaluate calls the retrivers with provided options.
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluatorOption) (*EvaluatorResponse, error) {
evalOpts := &evaluatorOptions{}
for _, opt := range opts {
err := opt.applyEvaluator(evalOpts)
if err != nil {
return nil, err
}
}
req := &EvaluatorRequest{
Dataset: evalOpts.Dataset,
EvaluationId: evalOpts.ID,
Options: evalOpts.Config,
}
return r.Evaluate(ctx, req)
}
// Name returns the name of the evaluator.
func (e evaluator) Name() string {
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Name()
}
// Evaluate runs the given [Evaluator].
func (e evaluator) Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Run(ctx, req, nil)
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"context"
"errors"
"fmt"
"testing"
"github.com/firebase/genkit/go/internal/registry"
)
var testEvalFunc = func(ctx context.Context, req *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error) {
m := make(map[string]any)
m["reasoning"] = "No good reason"
m["options"] = req.Options
score := Score{
Id: "testScore",
Score: 1,
Status: ScoreStatusPass.String(),
Details: m,
}
callbackResponse := EvaluatorCallbackResponse{
TestCaseId: req.Input.TestCaseId,
Evaluation: []Score{score},
}
return &callbackResponse, nil
}
var testBatchEvalFunc = func(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
var evalResponses []EvaluationResult
for _, datapoint := range req.Dataset {
fmt.Printf("%+v\n", datapoint)
m := make(map[string]any)
m["reasoning"] = fmt.Sprintf("batch of cookies, %s", datapoint.Input)
m["options"] = req.Options
score := Score{
Id: "testScore",
Score: true,
Status: ScoreStatusPass.String(),
Details: m,
}
callbackResponse := EvaluationResult{
TestCaseId: datapoint.TestCaseId,
Evaluation: []Score{score},
}
evalResponses = append(evalResponses, callbackResponse)
}
return &evalResponses, nil
}
var testFailingEvalFunc = func(ctx context.Context, req *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error) {
return nil, errors.New("i give up")
}
var evalOptions = EvaluatorOptions{
DisplayName: "Test Evaluator",
Definition: "Returns pass score for all",
IsBilled: false,
}
var dataset = []*Example{
{
Input: "hello world",
},
{
Input: "Foo bar",
},
}
var testRequest = EvaluatorRequest{
Dataset: dataset,
EvaluationId: "testrun",
Options: "test-options",
}
func TestSimpleEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
evalAction, err := DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testEvalFunc)
if err != nil {
t.Fatal(err)
}
resp, err := evalAction.Evaluate(context.Background(), &testRequest)
if err != nil {
t.Fatal(err)
}
if got, want := len(*resp), 2; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Id, "testScore"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Score, 1; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Status, "PASS"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Details["options"], "test-options"; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func TestOptionsRequired(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
_, err = DefineEvaluator(r, "test", "testEvaluator", nil, testEvalFunc)
if err == nil {
t.Errorf("expected error, got nil")
}
_, err = DefineBatchEvaluator(r, "test", "testBatchEvaluator", nil, testBatchEvalFunc)
if err == nil {
t.Errorf("expected error, got nil")
}
}
func TestFailingEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
evalAction, err := DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testFailingEvalFunc)
if err != nil {
t.Fatal(err)
}
resp, err := evalAction.Evaluate(context.Background(), &testRequest)
if err != nil {
t.Fatal(err)
}
if got, dontWant := (*resp)[0].Evaluation[0].Error, ""; got == dontWant {
t.Errorf("got %v, dontWant %v", got, dontWant)
}
if got, want := (*resp)[0].Evaluation[0].Status, "FAIL"; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func TestLookupEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
evalAction, err := DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testEvalFunc)
if err != nil {
t.Fatal(err)
}
batchEvalAction, err := DefineBatchEvaluator(r, "test", "testBatchEvaluator", &evalOptions, testBatchEvalFunc)
if err != nil {
t.Fatal(err)
}
if got, want := LookupEvaluator(r, "test", "testEvaluator"), evalAction; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := LookupEvaluator(r, "test", "testBatchEvaluator"), batchEvalAction; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func TestEvaluate(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
evalAction, err := DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testEvalFunc)
if err != nil {
t.Fatal(err)
}
resp, err := Evaluate(context.Background(), evalAction,
WithDataset(dataset...),
WithID("testrun"),
WithConfig("test-options"))
if err != nil {
t.Fatal(err)
}
if got, want := (*resp)[0].Evaluation[0].Id, "testScore"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Score, 1; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Status, "PASS"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Details["options"], "test-options"; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
func TestBatchEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}
evalAction, err := DefineBatchEvaluator(r, "test", "testBatchEvaluator", &evalOptions, testBatchEvalFunc)
if err != nil {
t.Fatal(err)
}
resp, err := evalAction.Evaluate(context.Background(), &testRequest)
if err != nil {
t.Fatal(err)
}
if got, want := len(*resp), 2; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Id, "testScore"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Score, true; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Status, "PASS"; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := (*resp)[0].Evaluation[0].Details["options"], "test-options"; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"encoding/json"
"errors"
"fmt"
"github.com/firebase/genkit/go/internal/base"
)
type arrayFormatter struct{}
// Name returns the name of the formatter.
func (a arrayFormatter) Name() string {
return OutputFormatArray
}
// Handler returns a new formatter handler for the given schema.
func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
if schema == nil || !base.ValidateIsJSONArray(schema) {
return nil, fmt.Errorf("schema is not valid JSON array")
}
jsonBytes, err := json.Marshal(schema["items"])
if err != nil {
return nil, fmt.Errorf("error marshalling schema to JSON, must supply an 'array' schema type when using the 'array' parser format.: %w", err)
}
instructions := fmt.Sprintf("Output should be a JSON array conforming to the following schema:\n\n```%s```", string(jsonBytes))
handler := &arrayHandler{
instructions: instructions,
config: ModelOutputConfig{
Format: OutputFormatArray,
Schema: schema,
ContentType: "application/json",
},
}
return handler, nil
}
type arrayHandler struct {
instructions string
config ModelOutputConfig
}
// Instructions returns the instructions for the formatter.
func (a arrayHandler) Instructions() string {
return a.instructions
}
// Config returns the output config for the formatter.
func (a arrayHandler) Config() ModelOutputConfig {
return a.config
}
// ParseMessage parses the message and returns the formatted message.
func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
if a.config.Format == OutputFormatArray {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}
var newParts []*Part
for _, part := range m.Content {
if !part.IsText() {
newParts = append(newParts, part)
} else {
lines := base.GetJsonObjectLines(part.Text)
for _, line := range lines {
var schemaBytes []byte
schemaBytes, err := json.Marshal(a.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}
newParts = append(newParts, NewJSONPart(line))
}
}
}
m.Content = newParts
}
return m, nil
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"errors"
"fmt"
"regexp"
"slices"
"strings"
)
type enumFormatter struct{}
// Name returns the name of the formatter.
func (e enumFormatter) Name() string {
return OutputFormatEnum
}
// Handler returns a new formatter handler for the given schema.
func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
enums := objectEnums(schema)
if schema == nil || len(enums) == 0 {
return nil, fmt.Errorf("schema is not valid JSON enum")
}
instructions := fmt.Sprintf("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n```%s```", strings.Join(enums, "\n"))
handler := &enumHandler{
instructions: instructions,
config: ModelOutputConfig{
Format: OutputFormatEnum,
Schema: schema,
ContentType: "text/enum",
},
enums: enums,
}
return handler, nil
}
type enumHandler struct {
instructions string
config ModelOutputConfig
enums []string
}
// Instructions returns the instructions for the formatter.
func (e enumHandler) Instructions() string {
return e.instructions
}
// Config returns the output config for the formatter.
func (e enumHandler) Config() ModelOutputConfig {
return e.config
}
// ParseMessage parses the message and returns the formatted message.
func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
if e.config.Format == OutputFormatEnum {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}
for i, part := range m.Content {
if !part.IsText() {
continue
}
// replace single and double quotes
re := regexp.MustCompile(`['"]`)
clean := re.ReplaceAllString(part.Text, "")
// trim whitespace
trimmed := strings.TrimSpace(clean)
if !slices.Contains(e.enums, trimmed) {
return nil, fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
}
m.Content[i] = NewTextPart(trimmed)
}
}
return m, nil
}
// Get enum strings from json schema
func objectEnums(schema map[string]any) []string {
var enums []string
if properties, ok := schema["properties"].(map[string]any); ok {
for _, propValue := range properties {
if propMap, ok := propValue.(map[string]any); ok {
if enumSlice, ok := propMap["enum"].([]any); ok {
for _, enumVal := range enumSlice {
if enumStr, ok := enumVal.(string); ok {
enums = append(enums, enumStr)
}
}
}
}
}
}
return enums
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"encoding/json"
"errors"
"fmt"
"github.com/firebase/genkit/go/internal/base"
)
type jsonFormatter struct{}
// Name returns the name of the formatter.
func (j jsonFormatter) Name() string {
return OutputFormatJSON
}
// Handler returns a new formatter handler for the given schema.
func (j jsonFormatter) Handler(schema map[string]any) (FormatHandler, error) {
var instructions string
if schema != nil {
jsonBytes, err := json.Marshal(schema)
if err != nil {
return nil, fmt.Errorf("error marshalling schema to JSON: %w", err)
}
instructions = fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", string(jsonBytes))
}
handler := &jsonHandler{
instructions: instructions,
config: ModelOutputConfig{
Format: OutputFormatJSON,
Schema: schema,
ContentType: "application/json",
},
}
return handler, nil
}
// jsonHandler is a handler for the JSON formatter.
type jsonHandler struct {
instructions string
config ModelOutputConfig
}
// Instructions returns the instructions for the formatter.
func (j jsonHandler) Instructions() string {
return j.instructions
}
// Config returns the output config for the formatter.
func (j jsonHandler) Config() ModelOutputConfig {
return j.config
}
// ParseMessage parses the message and returns the formatted message.
func (j jsonHandler) ParseMessage(m *Message) (*Message, error) {
if j.config.Format == OutputFormatJSON {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}
for i, part := range m.Content {
if !part.IsText() {
continue
}
text := base.ExtractJSONFromMarkdown(part.Text)
if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema)
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil {
return nil, err
}
} else {
if !base.ValidJSON(text) {
return nil, errors.New("message is not a valid JSON")
}
}
m.Content[i] = NewJSONPart(text)
}
}
return m, nil
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"encoding/json"
"errors"
"fmt"
"github.com/firebase/genkit/go/internal/base"
)
type jsonlFormatter struct{}
// Name returns the name of the formatter.
func (j jsonlFormatter) Name() string {
return OutputFormatJSONL
}
// Handler returns a new formatter handler for the given schema.
func (j jsonlFormatter) Handler(schema map[string]any) (FormatHandler, error) {
if schema == nil || !base.ValidateIsJSONArray(schema) {
return nil, fmt.Errorf("schema is not valid JSONL")
}
jsonBytes, err := json.Marshal(schema["items"])
if err != nil {
return nil, fmt.Errorf("error marshalling schema to JSONL: %w", err)
}
instructions := fmt.Sprintf("Output should be JSONL format, a sequence of JSON objects (one per line) separated by a newline '\\n' character. Each line should be a JSON object conforming to the following schema:\n\n```%s```", string(jsonBytes))
handler := &jsonlHandler{
instructions: instructions,
config: ModelOutputConfig{
Format: OutputFormatJSONL,
Schema: schema,
ContentType: "application/jsonl",
},
}
return handler, nil
}
type jsonlHandler struct {
instructions string
config ModelOutputConfig
}
// Instructions returns the instructions for the formatter.
func (j jsonlHandler) Instructions() string {
return j.instructions
}
// Config returns the output config for the formatter.
func (j jsonlHandler) Config() ModelOutputConfig {
return j.config
}
// ParseMessage parses the message and returns the formatted message.
func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) {
if j.config.Format == OutputFormatJSONL {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}
var newParts []*Part
for _, part := range m.Content {
if !part.IsText() {
newParts = append(newParts, part)
} else {
lines := base.GetJsonObjectLines(part.Text)
for _, line := range lines {
if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}
}
newParts = append(newParts, NewJSONPart(line))
}
}
}
m.Content = newParts
}
return m, nil
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
type textFormatter struct{}
// Name returns the name of the formatter.
func (t textFormatter) Name() string {
return OutputFormatText
}
// Handler returns a new formatter handler for the given schema.
func (t textFormatter) Handler(schema map[string]any) (FormatHandler, error) {
handler := &textHandler{
config: ModelOutputConfig{
ContentType: "text/plain",
},
}
return handler, nil
}
type textHandler struct {
instructions string
config ModelOutputConfig
}
// Config returns the output config for the formatter.
func (t textHandler) Config() ModelOutputConfig {
return t.config
}
// Instructions returns the instructions for the formatter.
func (t textHandler) Instructions() string {
return t.instructions
}
// ParseMessage parses the message and returns the formatted message.
func (t textHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import (
"fmt"
"github.com/firebase/genkit/go/internal/registry"
)
const (
OutputFormatText string = "text"
OutputFormatJSON string = "json"
OutputFormatJSONL string = "jsonl"
OutputFormatMedia string = "media"
OutputFormatArray string = "array"
OutputFormatEnum string = "enum"
)
// Default formats get automatically registered on registry init
var DEFAULT_FORMATS = []Formatter{
jsonFormatter{},
jsonlFormatter{},
textFormatter{},
arrayFormatter{},
enumFormatter{},
}
// Formatter represents the Formatter interface.
type Formatter interface {
// Name returns the name of the formatter.
Name() string
// Handler returns the handler for the formatter.
Handler(schema map[string]any) (FormatHandler, error)
}
// FormatHandler represents the handler part of the Formatter interface.
type FormatHandler interface {
// ParseMessage parses the message and returns a new formatted message.
ParseMessage(message *Message) (*Message, error)
// Instructions returns the formatter instructions to embed in the prompt.
Instructions() string
// Config returns the output config for the model request.
Config() ModelOutputConfig
}
// ConfigureFormats registers default formats in the registry
func ConfigureFormats(reg *registry.Registry) {
for _, format := range DEFAULT_FORMATS {
DefineFormat(reg, "/format/"+format.Name(), format)
}
}
// DefineFormat defines and registers a new [Formatter].
func DefineFormat(r *registry.Registry, name string, formatter Formatter) {
r.RegisterValue(name, formatter)
}
// resolveFormat returns a [Formatter], either a default one or one from the registry.
func resolveFormat(reg *registry.Registry, schema map[string]any, format string) (Formatter, error) {
var formatter any
// If schema is set but no explicit format is set we default to json.
if schema != nil && format == "" {
formatter = reg.LookupValue("/format/" + OutputFormatJSON)
}
// If format is not set we default to text
if format == "" {
formatter = reg.LookupValue("/format/" + OutputFormatText)
}
// Lookup format in registry
if format != "" {
formatter = reg.LookupValue("/format/" + format)
}
if f, ok := formatter.(Formatter); ok {
return f, nil
}
return nil, fmt.Errorf("output format %q is invalid", format)
}
// injectInstructions looks through the messages and injects formatting directives
func injectInstructions(messages []*Message, instructions string) []*Message {
if instructions == "" {
return messages
}
// bail out if an output part is already present
for _, m := range messages {
for _, p := range m.Content {
if p.Metadata != nil && p.Metadata["purpose"] == "output" {
return messages
}
}
}
part := NewTextPart(instructions)
part.Metadata = map[string]any{"purpose": "output"}
targetIndex := -1
// First try to find a system message
for i, m := range messages {
if m.Role == RoleSystem {
targetIndex = i
break
}
}
// If no system message, find the last user message
if targetIndex == -1 {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == RoleUser {
targetIndex = i
break
}
}
}
if targetIndex != -1 {
messages[targetIndex].Content = append(messages[targetIndex].Content, part)
}
return messages
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package core
import (
"context"
"encoding/json"
"github.com/firebase/genkit/go/internal/base"
)
var actionCtxKey = base.NewContextKey[int]()
// WithActionContext returns a new Context with Action runtime context (side channel data) value set.
func WithActionContext(ctx context.Context, actionCtx ActionContext) context.Context {
return context.WithValue(ctx, actionCtxKey, actionCtx)
}
// FromContext returns the Action runtime context (side channel data) from context.
func FromContext(ctx context.Context) ActionContext {
val := ctx.Value(actionCtxKey)
if val == nil {
return nil
}
return val.(ActionContext)
}
// ActionContext is the runtime context for an Action.
type ActionContext = map[string]any
// RequestData is the data associated with a request.
// It is used to provide additional context to the Action.
type RequestData struct {
Method string // Method is the HTTP method of the request (e.g. "GET", "POST", etc.)
Headers map[string]string // Headers is the headers of the request. The keys are the header names in lowercase.
Input json.RawMessage // Input is the body of the request.
}
// ContextProvider is a function that returns an ActionContext for a given request.
// It is used to provide additional context to the Action.
type ContextProvider = func(ctx context.Context, req RequestData) (ActionContext, error)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
-- source --
first
second
//copy:start dest foo
third
fourth
//copy:stop
fifth
//copy:start dest bar
sixth
//copy:stop
seventh
-- dest --
line1
//copy:sink bar
line2
//copy:sink foo
-- want --
line1
//copy:sink bar from source
// DO NOT MODIFY below vvvv
sixth
// DO NOT MODIFY above ^^^^
//copy:endsink bar
line2
//copy:sink foo from source
// DO NOT MODIFY below vvvv
third
fourth
// DO NOT MODIFY above ^^^^
//copy:endsink foo
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment