Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions go/plugins/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"slices"
"strings"
Expand Down Expand Up @@ -56,8 +57,52 @@ var (
ai.RoleSystem: "system",
ai.RoleTool: "tool",
}
// defaultOllamaSupports defines the default capabilities for dynamically
// discovered Ollama models. All capabilities are enabled since local models
// vary widely and we can't query their capabilities individually.
defaultOllamaSupports = ai.ModelSupports{
Multiturn: true,
Media: true,
Tools: true,
SystemRole: true,
}
)

// ollamaTagsResponse represents the response from GET /api/tags.
type ollamaTagsResponse struct {
	Models []ollamaLocalModel `json:"models"`
}

// ollamaLocalModel represents a locally available Ollama model.
type ollamaLocalModel struct {
	Name  string `json:"name"`  // display name, e.g. "llama3:latest"
	Model string `json:"model"` // model identifier as reported by Ollama
}

// listLocalModels calls GET /api/tags to list locally installed Ollama models.
// The request is bound to ctx, so callers control cancellation and deadlines.
// It returns an error if the request cannot be built or sent, if the server
// responds with a non-200 status, or if the body is not valid JSON.
func listLocalModels(ctx context.Context, serverAddress string) ([]ollamaLocalModel, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, serverAddress+"/api/tags", nil)
	if err != nil {
		return nil, fmt.Errorf("creating /api/tags request: %w", err)
	}
	// No per-request client configuration is needed; cancellation and
	// timeouts come from ctx, so the shared default client suffices.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetching local models from Ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ollama /api/tags returned status %d", resp.StatusCode)
	}

	var tagsResp ollamaTagsResponse
	if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
		return nil, fmt.Errorf("decoding /api/tags response: %w", err)
	}
	return tagsResp.Models, nil
}

func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, opts *ai.ModelOptions) ai.Model {
o.mu.Lock()
defer o.mu.Unlock()
Expand Down Expand Up @@ -226,6 +271,57 @@ func (o *Ollama) Init(ctx context.Context) []api.Action {
return []api.Action{}
}

// newModel constructs an Ollama model handle without registering it in the
// Genkit registry. ListActions uses it to build ActionDescs and ResolveAction
// uses it to materialize an Action on demand.
func (o *Ollama) newModel(name string, opts ai.ModelOptions) ai.Model {
	g := &generator{
		model:         ModelDefinition{Name: name, Type: "chat"},
		serverAddress: o.ServerAddress,
		timeout:       o.Timeout,
	}
	return ai.NewModel(api.NewName(provider, name), &ai.ModelOptions{
		Label:    "Ollama - " + name,
		Supports: opts.Supports,
		Versions: []string{},
	}, g.generate)
}

// ListActions calls /api/tags to discover locally installed Ollama models and
// returns a model ActionDesc for each non-embedding model. On any discovery
// error it logs the failure and returns nil.
func (o *Ollama) ListActions(ctx context.Context) []api.ActionDesc {
	models, err := listLocalModels(ctx, o.ServerAddress)
	if err != nil {
		slog.Error("unable to list ollama models", "error", err)
		return nil
	}

	var descs []api.ActionDesc
	for _, m := range models {
		// Embedding models cannot serve chat requests; drop them
		// (following JS: !m.model.includes('embed')).
		if strings.Contains(m.Name, "embed") {
			continue
		}
		if a, ok := o.newModel(m.Name, ai.ModelOptions{Supports: &defaultOllamaSupports}).(api.Action); ok {
			descs = append(descs, a.Desc())
		}
	}
	return descs
}

// ResolveAction dynamically creates a model action on demand. Only model
// actions are supported; any other action type resolves to nil.
func (o *Ollama) ResolveAction(atype api.ActionType, name string) api.Action {
	if atype != api.ActionTypeModel {
		return nil
	}
	action, ok := o.newModel(name, ai.ModelOptions{Supports: &defaultOllamaSupports}).(api.Action)
	if !ok {
		return nil
	}
	return action
}

// Generate makes a request to the Ollama API and processes the response.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
stream := cb != nil
Expand Down
46 changes: 46 additions & 0 deletions go/plugins/ollama/ollama_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

// Command-line flags for the live tests. The -test-live flag gates execution
// so these tests are skipped in environments without a running Ollama server.
var serverAddress = flag.String("server-address", "http://localhost:11434", "Ollama server address")
var modelName = flag.String("model-name", "tinyllama", "model name")
var dynamicModelName = flag.String("dynamic-model-name", "moondream", "model name for dynamic discovery test (must not be in hardcoded lists)")
var testLive = flag.Bool("test-live", false, "run live tests")

/*
Expand Down Expand Up @@ -77,3 +78,48 @@ func TestLive(t *testing.T) {
t.Fatalf("expected non-empty response, got: %s", text)
}
}

// TestLiveDynamicDiscovery verifies that a model NOT registered via DefineModel
// can be discovered and used through the DynamicPlugin interface
// (ListActions + ResolveAction). Requires a live Ollama server.
func TestLiveDynamicDiscovery(t *testing.T) {
	if !*testLive {
		t.Skip("skipping go/plugins/ollama live dynamic discovery test")
	}

	ctx := context.Background()
	plugin := &ollamaPlugin.Ollama{ServerAddress: *serverAddress}
	g := genkit.Init(ctx, genkit.WithPlugins(plugin))

	// ListActions should report the locally installed models.
	discovered := plugin.ListActions(ctx)
	if len(discovered) == 0 {
		t.Fatal("ListActions() returned no actions, ensure Ollama has local models")
	}
	t.Logf("ListActions() discovered %d models:", len(discovered))
	for _, d := range discovered {
		t.Logf(" - %s", d.Name)
	}

	// Look up a model that is NOT in the hardcoded lists; this goes through
	// LookupModel and therefore triggers ResolveAction under the hood.
	model := ollamaPlugin.Model(g, *dynamicModelName)
	if model == nil {
		t.Fatalf("Model(%q) returned nil — ResolveAction did not work", *dynamicModelName)
	}

	// Ask the dynamically resolved model for a short completion.
	resp, err := genkit.Generate(ctx, g,
		ai.WithModel(model),
		ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 1}),
		ai.WithPrompt("Say hello in one sentence."),
	)
	if err != nil {
		t.Fatalf("failed to generate with dynamic model %q: %s", *dynamicModelName, err)
	}

	out := resp.Text()
	t.Logf("Dynamic model %q response: %s", *dynamicModelName, out)
	if out == "" {
		t.Fatalf("expected non-empty response from dynamic model %q", *dynamicModelName)
	}
}
157 changes: 157 additions & 0 deletions go/plugins/ollama/ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
package ollama

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core/api"
)

// Compile-time assertions that Ollama satisfies both the static plugin
// interface and the dynamic-discovery plugin interface.
var _ api.Plugin = (*Ollama)(nil)
var _ api.DynamicPlugin = (*Ollama)(nil)

func TestConcatMessages(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -131,3 +136,155 @@ func equalContent(a, b []*ai.Part) bool {
}
return true
}

// newTestOllama returns an initialized Ollama plugin pointed at serverAddress,
// for use in unit tests.
func newTestOllama(serverAddress string) *Ollama {
	plugin := &Ollama{ServerAddress: serverAddress, Timeout: 30}
	plugin.Init(context.Background())
	return plugin
}

// TestDynamicPlugin exercises the dynamic-discovery surface of the plugin:
// listLocalModels (the /api/tags HTTP call), ListActions (discovery plus
// embed-model filtering), ResolveAction (on-demand model creation), and
// newModel (the shared constructor).
func TestDynamicPlugin(t *testing.T) {
	t.Run("listLocalModels", func(t *testing.T) {
		tests := []struct {
			name       string
			response   ollamaTagsResponse
			statusCode int
			wantCount  int
			wantErr    bool
		}{
			{
				name: "successful response with multiple models",
				response: ollamaTagsResponse{
					Models: []ollamaLocalModel{
						{Name: "llama3:latest", Model: "llama3:latest"},
						{Name: "mistral:7b", Model: "mistral:7b"},
					},
				},
				statusCode: http.StatusOK,
				wantCount:  2,
			},
			{
				name:       "empty model list",
				response:   ollamaTagsResponse{Models: []ollamaLocalModel{}},
				statusCode: http.StatusOK,
				wantCount:  0,
			},
			{
				// No body is written in this case; only the status code matters.
				name:       "server error",
				statusCode: http.StatusInternalServerError,
				wantErr:    true,
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				// Fake Ollama server that also validates the request shape
				// (path and method) before replying.
				server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					if r.URL.Path != "/api/tags" {
						t.Errorf("unexpected path: %s", r.URL.Path)
					}
					if r.Method != http.MethodGet {
						t.Errorf("unexpected method: %s", r.Method)
					}
					w.WriteHeader(tt.statusCode)
					if tt.statusCode == http.StatusOK {
						json.NewEncoder(w).Encode(tt.response)
					}
				}))
				defer server.Close()

				models, err := listLocalModels(context.Background(), server.URL)
				if (err != nil) != tt.wantErr {
					t.Errorf("listLocalModels() error = %v, wantErr %v", err, tt.wantErr)
					return
				}
				if !tt.wantErr && len(models) != tt.wantCount {
					t.Errorf("listLocalModels() returned %d models, want %d", len(models), tt.wantCount)
				}
			})
		}
	})

	t.Run("ListActions", func(t *testing.T) {
		t.Run("filters embed models", func(t *testing.T) {
			// Three local models, one of which is an embedder that the
			// "embed" name filter must drop.
			response := ollamaTagsResponse{
				Models: []ollamaLocalModel{
					{Name: "llama3:latest", Model: "llama3:latest"},
					{Name: "nomic-embed-text:latest", Model: "nomic-embed-text:latest"},
					{Name: "moondream:v2", Model: "moondream:v2"},
				},
			}

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				json.NewEncoder(w).Encode(response)
			}))
			defer server.Close()

			o := newTestOllama(server.URL)
			actions := o.ListActions(context.Background())

			if len(actions) != 2 {
				t.Fatalf("ListActions() returned %d actions, want 2", len(actions))
			}

			names := make(map[string]bool)
			for _, a := range actions {
				names[a.Name] = true
			}
			if !names["ollama/llama3:latest"] {
				t.Error("ListActions() missing ollama/llama3:latest")
			}
			if !names["ollama/moondream:v2"] {
				t.Error("ListActions() missing ollama/moondream:v2")
			}
			if names["ollama/nomic-embed-text:latest"] {
				t.Error("ListActions() should have filtered out embed model")
			}
		})

		t.Run("server unreachable", func(t *testing.T) {
			// Nothing listens on port 0, so the /api/tags request must fail
			// and ListActions must fall back to nil.
			o := newTestOllama("http://localhost:0")
			actions := o.ListActions(context.Background())
			if actions != nil {
				t.Errorf("ListActions() should return nil when server is unreachable, got %v", actions)
			}
		})
	})

	t.Run("ResolveAction", func(t *testing.T) {
		// ResolveAction only constructs a model handle; it never contacts the
		// server, so the address here is not dialed.
		o := newTestOllama("http://localhost:11434")

		t.Run("model action type", func(t *testing.T) {
			action := o.ResolveAction(api.ActionTypeModel, "llama3:latest")
			if action == nil {
				t.Fatal("ResolveAction() returned nil for model type")
			}
			desc := action.Desc()
			if desc.Name != "ollama/llama3:latest" {
				t.Errorf("ResolveAction() name = %q, want %q", desc.Name, "ollama/llama3:latest")
			}
		})

		t.Run("non-model action type", func(t *testing.T) {
			action := o.ResolveAction(api.ActionTypeExecutablePrompt, "llama3:latest")
			if action != nil {
				t.Error("ResolveAction() should return nil for non-model action type")
			}
		})
	})

	t.Run("newModel", func(t *testing.T) {
		o := newTestOllama("http://localhost:11434")
		model := o.newModel("test-model", ai.ModelOptions{Supports: &defaultOllamaSupports})
		if model == nil {
			t.Fatal("newModel() returned nil")
		}
		action, ok := model.(api.Action)
		if !ok {
			t.Fatal("newModel() result does not implement api.Action")
		}
		desc := action.Desc()
		if desc.Name != "ollama/test-model" {
			t.Errorf("newModel() name = %q, want %q", desc.Name, "ollama/test-model")
		}
	})
}
Loading