diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go
index 4eb4469673..2035718887 100644
--- a/go/plugins/ollama/ollama.go
+++ b/go/plugins/ollama/ollama.go
@@ -25,6 +25,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"net/http"
 	"slices"
 	"strings"
@@ -56,8 +57,52 @@ var (
 		ai.RoleSystem: "system",
 		ai.RoleTool:   "tool",
 	}
+	// defaultOllamaSupports defines the default capabilities for dynamically
+	// discovered Ollama models. All capabilities are enabled since local models
+	// vary widely and we can't query their capabilities individually.
+	defaultOllamaSupports = ai.ModelSupports{
+		Multiturn:  true,
+		Media:      true,
+		Tools:      true,
+		SystemRole: true,
+	}
 )
 
+// ollamaTagsResponse represents the response from GET /api/tags.
+type ollamaTagsResponse struct {
+	Models []ollamaLocalModel `json:"models"`
+}
+
+// ollamaLocalModel represents a locally available Ollama model.
+type ollamaLocalModel struct {
+	Name  string `json:"name"`
+	Model string `json:"model"`
+}
+
+// listLocalModels calls GET /api/tags to list locally installed Ollama models.
+func listLocalModels(ctx context.Context, serverAddress string) ([]ollamaLocalModel, error) {
+	req, err := http.NewRequestWithContext(ctx, "GET", serverAddress+"/api/tags", nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch local models from Ollama: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("ollama /api/tags returned status %d", resp.StatusCode)
+	}
+
+	var tagsResp ollamaTagsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
+		return nil, fmt.Errorf("failed to decode /api/tags response: %w", err)
+	}
+	return tagsResp.Models, nil
+}
+
 func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, opts *ai.ModelOptions) ai.Model {
 	o.mu.Lock()
 	defer o.mu.Unlock()
@@ -226,6 +271,57 @@ func (o *Ollama) Init(ctx context.Context) []api.Action {
 	return []api.Action{}
 }
 
+// newModel creates an Ollama model without registering it in the Genkit registry.
+// It is used by ListActions (to generate ActionDesc) and ResolveAction (to return an Action).
+func (o *Ollama) newModel(name string, opts ai.ModelOptions) ai.Model {
+	meta := &ai.ModelOptions{
+		Label:    "Ollama - " + name,
+		Supports: opts.Supports,
+		Versions: []string{},
+	}
+	gen := &generator{
+		model:         ModelDefinition{Name: name, Type: "chat"},
+		serverAddress: o.ServerAddress,
+		timeout:       o.Timeout,
+	}
+	return ai.NewModel(api.NewName(provider, name), meta, gen.generate)
+}
+
+// ListActions calls /api/tags to discover locally installed Ollama models.
+func (o *Ollama) ListActions(ctx context.Context) []api.ActionDesc {
+	models, err := listLocalModels(ctx, o.ServerAddress)
+	if err != nil {
+		slog.Error("unable to list ollama models", "error", err)
+		return nil
+	}
+
+	var actions []api.ActionDesc
+	for _, m := range models {
+		name := m.Name
+		// Filter out embedding models (following JS: !m.model.includes('embed'))
+		if strings.Contains(name, "embed") {
+			continue
+		}
+		model := o.newModel(name, ai.ModelOptions{Supports: &defaultOllamaSupports})
+		if action, ok := model.(api.Action); ok {
+			actions = append(actions, action.Desc())
+		}
+	}
+	return actions
+}
+
+// ResolveAction dynamically creates a model action on demand.
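+// The action is created for whatever name is requested, without verifying
+// that the model is installed locally; if the model is missing, generation
+// fails when the action is first called.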
+func (o *Ollama) ResolveAction(atype api.ActionType, name string) api.Action {
+	if atype != api.ActionTypeModel {
+		return nil
+	}
+	model := o.newModel(name, ai.ModelOptions{Supports: &defaultOllamaSupports})
+	if action, ok := model.(api.Action); ok {
+		return action
+	}
+	return nil
+}
+
 // Generate makes a request to the Ollama API and processes the response.
 func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	stream := cb != nil
diff --git a/go/plugins/ollama/ollama_live_test.go b/go/plugins/ollama/ollama_live_test.go
index 1c16886e02..377464baa2 100644
--- a/go/plugins/ollama/ollama_live_test.go
+++ b/go/plugins/ollama/ollama_live_test.go
@@ -28,6 +28,7 @@ import (
 
 var serverAddress = flag.String("server-address", "http://localhost:11434", "Ollama server address")
 var modelName = flag.String("model-name", "tinyllama", "model name")
+var dynamicModelName = flag.String("dynamic-model-name", "moondream", "model name for dynamic discovery test (must not be in hardcoded lists)")
 var testLive = flag.Bool("test-live", false, "run live tests")
 
 /*
@@ -77,3 +78,48 @@ func TestLive(t *testing.T) {
 		t.Fatalf("expected non-empty response, got: %s", text)
 	}
 }
+
+// TestLiveDynamicDiscovery verifies that a model NOT registered via DefineModel
+// can be discovered and used through the DynamicPlugin interface (ListActions + ResolveAction).
+func TestLiveDynamicDiscovery(t *testing.T) {
+	if !*testLive {
+		t.Skip("skipping go/plugins/ollama live dynamic discovery test")
+	}
+
+	ctx := context.Background()
+	o := &ollamaPlugin.Ollama{ServerAddress: *serverAddress}
+	g := genkit.Init(ctx, genkit.WithPlugins(o))
+
+	// Verify ListActions discovers local models
+	actions := o.ListActions(ctx)
+	if len(actions) == 0 {
+		t.Fatal("ListActions() returned no actions, ensure Ollama has local models")
+	}
+	t.Logf("ListActions() discovered %d models:", len(actions))
+	for _, a := range actions {
+		t.Logf("  - %s", a.Name)
+	}
+
+	// Use a model that is NOT in the hardcoded lists via LookupModel,
+	// which triggers ResolveAction under the hood.
+	m := ollamaPlugin.Model(g, *dynamicModelName)
+	if m == nil {
+		t.Fatalf("Model(%q) returned nil; ResolveAction did not work", *dynamicModelName)
+	}
+
+	// Generate a response from the dynamically resolved model
+	resp, err := genkit.Generate(ctx, g,
+		ai.WithModel(m),
+		ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 1}),
+		ai.WithPrompt("Say hello in one sentence."),
+	)
+	if err != nil {
+		t.Fatalf("failed to generate with dynamic model %q: %s", *dynamicModelName, err)
+	}
+
+	text := resp.Text()
+	t.Logf("Dynamic model %q response: %s", *dynamicModelName, text)
+	if text == "" {
+		t.Fatalf("expected non-empty response from dynamic model %q", *dynamicModelName)
+	}
+}
diff --git a/go/plugins/ollama/ollama_test.go b/go/plugins/ollama/ollama_test.go
index 1e41299e85..28a2d157f4 100644
--- a/go/plugins/ollama/ollama_test.go
+++ b/go/plugins/ollama/ollama_test.go
@@ -17,6 +17,10 @@ package ollama
 
 import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
 	"testing"
 
 	"github.com/firebase/genkit/go/ai"
@@ -24,6 +28,7 @@ import (
 
 var _ api.Plugin = (*Ollama)(nil)
+var _ api.DynamicPlugin = (*Ollama)(nil)
 
 func TestConcatMessages(t *testing.T) {
 	tests := []struct {
@@ -131,3 +136,155 @@ func equalContent(a, b []*ai.Part) bool {
 	}
 	return true
 }
+
+// newTestOllama returns an initialized Ollama plugin pointed at the given server address.
+func newTestOllama(serverAddress string) *Ollama {
+	o := &Ollama{ServerAddress: serverAddress, Timeout: 30}
+	o.Init(context.Background())
+	return o
+}
+
+func TestDynamicPlugin(t *testing.T) {
+	t.Run("listLocalModels", func(t *testing.T) {
+		tests := []struct {
+			name       string
+			response   ollamaTagsResponse
+			statusCode int
+			wantCount  int
+			wantErr    bool
+		}{
+			{
+				name: "successful response with multiple models",
+				response: ollamaTagsResponse{
+					Models: []ollamaLocalModel{
+						{Name: "llama3:latest", Model: "llama3:latest"},
+						{Name: "mistral:7b", Model: "mistral:7b"},
+					},
+				},
+				statusCode: http.StatusOK,
+				wantCount:  2,
+			},
+			{
+				name:       "empty model list",
+				response:   ollamaTagsResponse{Models: []ollamaLocalModel{}},
+				statusCode: http.StatusOK,
+				wantCount:  0,
+			},
+			{
+				name:       "server error",
+				statusCode: http.StatusInternalServerError,
+				wantErr:    true,
+			},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					if r.URL.Path != "/api/tags" {
+						t.Errorf("unexpected path: %s", r.URL.Path)
+					}
+					if r.Method != http.MethodGet {
+						t.Errorf("unexpected method: %s", r.Method)
+					}
+					w.WriteHeader(tt.statusCode)
+					if tt.statusCode == http.StatusOK {
+						json.NewEncoder(w).Encode(tt.response)
+					}
+				}))
+				defer server.Close()
+
+				models, err := listLocalModels(context.Background(), server.URL)
+				if (err != nil) != tt.wantErr {
+					t.Errorf("listLocalModels() error = %v, wantErr %v", err, tt.wantErr)
+					return
+				}
+				if !tt.wantErr && len(models) != tt.wantCount {
+					t.Errorf("listLocalModels() returned %d models, want %d", len(models), tt.wantCount)
+				}
+			})
+		}
+	})
+
+	t.Run("ListActions", func(t *testing.T) {
+		t.Run("filters embed models", func(t *testing.T) {
+			response := ollamaTagsResponse{
+				Models: []ollamaLocalModel{
+					{Name: "llama3:latest", Model: "llama3:latest"},
+					{Name: "nomic-embed-text:latest", Model: "nomic-embed-text:latest"},
+					{Name: "moondream:v2", Model: "moondream:v2"},
+				},
+			}
+
+			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				json.NewEncoder(w).Encode(response)
+			}))
+			defer server.Close()
+
+			o := newTestOllama(server.URL)
+			actions := o.ListActions(context.Background())
+
+			if len(actions) != 2 {
+				t.Fatalf("ListActions() returned %d actions, want 2", len(actions))
+			}
+
+			names := make(map[string]bool)
+			for _, a := range actions {
+				names[a.Name] = true
+			}
+			if !names["ollama/llama3:latest"] {
+				t.Error("ListActions() missing ollama/llama3:latest")
+			}
+			if !names["ollama/moondream:v2"] {
+				t.Error("ListActions() missing ollama/moondream:v2")
+			}
+			if names["ollama/nomic-embed-text:latest"] {
+				t.Error("ListActions() should have filtered out embed model")
+			}
+		})
+
+		t.Run("server unreachable", func(t *testing.T) {
+			o := newTestOllama("http://localhost:0")
+			actions := o.ListActions(context.Background())
+			if actions != nil {
+				t.Errorf("ListActions() should return nil when server is unreachable, got %v", actions)
+			}
+		})
+	})
+
+	t.Run("ResolveAction", func(t *testing.T) {
+		o := newTestOllama("http://localhost:11434")
+
+		t.Run("model action type", func(t *testing.T) {
+			action := o.ResolveAction(api.ActionTypeModel, "llama3:latest")
+			if action == nil {
+				t.Fatal("ResolveAction() returned nil for model type")
+			}
+			desc := action.Desc()
+			if desc.Name != "ollama/llama3:latest" {
+				t.Errorf("ResolveAction() name = %q, want %q", desc.Name, "ollama/llama3:latest")
+			}
+		})
+
+		t.Run("non-model action type", func(t *testing.T) {
+			action := o.ResolveAction(api.ActionTypeExecutablePrompt, "llama3:latest")
+			if action != nil {
+				t.Error("ResolveAction() should return nil for non-model action type")
+			}
+		})
+	})
+
+	t.Run("newModel", func(t *testing.T) {
+		o := newTestOllama("http://localhost:11434")
+		model := o.newModel("test-model", ai.ModelOptions{Supports: &defaultOllamaSupports})
+		if model == nil {
+			t.Fatal("newModel() returned nil")
+		}
+		action, ok := model.(api.Action)
+		if !ok {
+			t.Fatal("newModel() result does not implement api.Action")
+		}
+		desc := action.Desc()
+		if desc.Name != "ollama/test-model" {
+			t.Errorf("newModel() name = %q, want %q", desc.Name, "ollama/test-model")
+		}
+	})
+}
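
For reviewers, a minimal usage sketch of what this change enables: resolving a model that was never passed to DefineModel. It mirrors the live test above; the model name "moondream" and the server address are placeholders for any locally installed model.

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/firebase/genkit/go/ai"
		"github.com/firebase/genkit/go/genkit"
		ollamaPlugin "github.com/firebase/genkit/go/plugins/ollama"
	)

	func main() {
		ctx := context.Background()

		// Register the plugin without defining any models up front.
		o := &ollamaPlugin.Ollama{ServerAddress: "http://localhost:11434"}
		g := genkit.Init(ctx, genkit.WithPlugins(o))

		// Looking up an unregistered model goes through ResolveAction.
		m := ollamaPlugin.Model(g, "moondream")

		resp, err := genkit.Generate(ctx, g,
			ai.WithModel(m),
			ai.WithPrompt("Say hello in one sentence."),
		)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(resp.Text())
	}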