diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index c08ecceb8..ea5afd223 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -3,6 +3,7 @@ package modelsdev import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -23,6 +24,9 @@ const ( refreshInterval = 24 * time.Hour ) +// ErrProviderNotFound is returned when a requested provider is not found in the database. +var ErrProviderNotFound = errors.New("provider not found") + // Store manages access to the models.dev data. // All methods are safe for concurrent use. // @@ -91,7 +95,7 @@ func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider, provider, exists := db.Providers[providerID] if !exists { - return nil, fmt.Errorf("provider %q not found", providerID) + return nil, fmt.Errorf("%w: %q", ErrProviderNotFound, providerID) } return &provider, nil diff --git a/pkg/runtime/session_compaction.go b/pkg/runtime/session_compaction.go index afa015aed..f56449ad5 100644 --- a/pkg/runtime/session_compaction.go +++ b/pkg/runtime/session_compaction.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker-agent/pkg/compaction" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" ) @@ -26,7 +27,9 @@ const maxKeepTokens = 20_000 // persistence, token count updates). The agent is used to extract the // conversation from the session and to obtain the model for summarization. func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *agent.Agent, additionalPrompt string, events chan Event) { - slog.Debug("Generating summary for session", "session_id", sess.ID) + lg := slog.With("session_id", sess.ID, "agent", a.Name(), "action", "compaction") + + lg.Debug("Generating summary for session") events <- SessionCompaction(sess.ID, "started", a.Name()) defer func() { events <- SessionCompaction(sess.ID, "completed", a.Name()) @@ -37,10 +40,32 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a * options.WithStructuredOutput(nil), options.WithMaxTokens(maxSummaryTokens), ) + m, err := r.modelsStore.GetModel(ctx, summaryModel.ID()) + if err != nil && errors.Is(err, modelsdev.ErrProviderNotFound) { + lg.Debug("Provider not found; attempting to find by model name", "error", err) + + db, dberr := r.modelsStore.GetDatabase(ctx) + if dberr != nil { + lg.Error("Provider not found and failed to find by model name", "error", dberr) + events <- Error("Failed to get db to find model definition: " + dberr.Error()) + return + } + + // Find the lowest context limit for this model, regardless of the provider. + for _, provider := range db.Providers { + if v, ok := provider.Models[summaryModel.BaseConfig().ModelConfig.Model]; ok { + if m == nil || v.Limit.Context < m.Limit.Context { + m = &v + err = nil + } + } + } + } + if err != nil { - slog.Error("Failed to generate session summary", "error", errors.New("failed to get model definition")) - events <- Error("Failed to get model definition") + lg.Error("Failed to get model definition to generate session summary", "error", err) + events <- Error("Failed to get model definition: " + err.Error()) return } @@ -58,12 +83,12 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a * t := team.New(team.WithAgents(compactionAgent)) rt, err := New(t, WithSessionCompaction(false)) if err != nil { - slog.Error("Failed to generate session summary", "error", err) + lg.Error("Failed to generate session summary", "error", err) events <- Error(err.Error()) return } if _, err = rt.Run(ctx, compactionSession); err != nil { - slog.Error("Failed to generate session summary", "error", err) + lg.Error("Failed to generate session summary", "error", err) events <- Error(err.Error()) return } @@ -83,7 +108,7 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a * }) _ = r.sessionStore.UpdateSession(ctx, sess) - slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary)) + lg.Debug("Generated session summary", "summary_length", len(summary)) events <- SessionSummary(sess.ID, summary, a.Name(), firstKeptEntry) }