diff --git a/bridges.manifest.yml b/bridges.manifest.yml index cfabd70d..77c35d87 100644 --- a/bridges.manifest.yml +++ b/bridges.manifest.yml @@ -15,6 +15,25 @@ instances: bridge.permissions: "*": relay beeper.com: admin + network.agents.enabled: false + + agent: + bridge_type: ai + mode: local-repo + repo_path: . + build_cmd: ./build.sh + binary_path: ./ai + beeper_bridge_name: sh-agent + config_overrides: + appservice.address: websocket + appservice.hostname: 127.0.0.1 + appservice.port: 29350 + database.type: sqlite3-fk-wal + database.uri: file:agent.db?_txlock=immediate + bridge.permissions: + "*": relay + beeper.com: admin + network.agents.enabled: true codex: bridge_type: codex diff --git a/bridges/ai/agent_activity.go b/bridges/ai/agent_activity.go index 69f2703f..bee8f278 100644 --- a/bridges/ai/agent_activity.go +++ b/bridges/ai/agent_activity.go @@ -21,6 +21,9 @@ func (oc *AIClient) recordAgentActivity(ctx context.Context, portal *bridgev2.Po if heartbeatRunFromContext(ctx) != nil { return } + if oc.agentTargetBlocked(meta) { + return + } agentID := normalizeAgentID(resolveAgentID(meta)) if agentID == "" { return @@ -93,12 +96,12 @@ func (oc *AIClient) defaultChatPortal() *bridgev2.Portal { Receiver: oc.UserLogin.ID, } if portal, err := oc.UserLogin.Bridge.GetPortalByKey(ctx, portalKey); err == nil && portal != nil { - if isDefaultChatCandidate(portal) { + if oc.isDefaultChatCandidate(portal) { return portal } } } - if portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultChatPortalKey(oc.UserLogin.ID)); err == nil && portal != nil && isDefaultChatCandidate(portal) { + if portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultChatPortalKey(oc.UserLogin.ID)); err == nil && portal != nil && oc.isDefaultChatCandidate(portal) { return portal } return nil diff --git a/bridges/ai/agent_mode.go b/bridges/ai/agent_mode.go new file mode 100644 index 00000000..2efddb1a --- /dev/null +++ b/bridges/ai/agent_mode.go @@ -0,0 +1,80 @@ +package ai + +import ( + "errors" + + "maunium.net/go/mautrix/bridgev2" +) + +var errAgentsDisabled = errors.New("agents are disabled by bridge config") + +func (c *Config) agentsEnabled() bool { + if c == nil || c.Agents == nil || c.Agents.Enabled == nil { + return true + } + return *c.Agents.Enabled +} + +func (oc *OpenAIConnector) agentsEnabled() bool { + if oc == nil { + return true + } + return oc.Config.agentsEnabled() +} + +func (oc *AIClient) agentsEnabled() bool { + if oc == nil || oc.connector == nil { + return true + } + return oc.connector.agentsEnabled() +} + +func (oc *AIClient) agentFeaturesDisabledErr() error { + return errAgentsDisabled +} + +func (oc *AIClient) agentTargetBlocked(meta *PortalMetadata) bool { + return oc != nil && !oc.agentsEnabled() && resolveAgentID(meta) != "" +} + +func (oc *AIClient) ensureAgentTargetAllowed(meta *PortalMetadata) error { + if oc.agentTargetBlocked(meta) { + return oc.agentFeaturesDisabledErr() + } + return nil +} + +func (oc *AIClient) shouldExcludeVisiblePortal(meta *PortalMetadata) bool { + if shouldExcludeModelVisiblePortal(meta) { + return true + } + return oc.agentTargetBlocked(meta) +} + +func (oc *AIClient) isDefaultChatCandidate(portal *bridgev2.Portal) bool { + return portal != nil && !oc.shouldExcludeVisiblePortal(portalMeta(portal)) +} + +func (oc *AIClient) chooseDefaultChatPortal(portals []*bridgev2.Portal) *bridgev2.Portal { + var defaultPortal *bridgev2.Portal + var ( + minIdx int + haveSlug bool + ) + for _, portal := range portals { + if !oc.isDefaultChatCandidate(portal) { + continue + } + pm := portalMeta(portal) + if idx, ok := parseChatSlug(pm.Slug); ok { + if !haveSlug || idx < minIdx { + minIdx = idx + defaultPortal = portal + haveSlug = true + } + } else if defaultPortal == nil && !haveSlug { + defaultPortal = portal + } + } + return defaultPortal +} diff --git a/bridges/ai/chat.go b/bridges/ai/chat.go index 3d043857..bde4c6bf 100644 --- a/bridges/ai/chat.go +++ b/bridges/ai/chat.go @@ -409,6 +409,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho return resp.Chat, nil } if agentID, ok := parseAgentFromGhostID(ghostID); ok { + if !oc.agentsEnabled() { + return nil, bridgev2.WrapRespErr(oc.agentFeaturesDisabledErr(), mautrix.MForbidden) + } store := NewAgentStoreAdapter(oc) agent, err := store.GetAgentByID(ctx, agentID) if err != nil || agent == nil { @@ -425,6 +428,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho // resolveAgentIdentifier resolves an agent to a ghost and optionally creates a chat. func (oc *AIClient) resolveAgentIdentifier(ctx context.Context, agent *agents.AgentDefinition, modelID string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + if !oc.agentsEnabled() { + return nil, oc.agentFeaturesDisabledErr() + } explicitModel := modelID != "" if modelID == "" { modelID = oc.agentDefaultModel(agent) @@ -730,6 +736,9 @@ func (oc *AIClient) resolveNewChatTarget( if cmd != "agent" { return nil, "", errors.New(usage) } + if !oc.agentsEnabled() { + return nil, "", oc.agentFeaturesDisabledErr() + } targetID := args[1] if targetID == "" || len(args) > 2 { return nil, "", errors.New(usage) @@ -753,6 +762,9 @@ func (oc *AIClient) resolveNewChatTarget( } agentID := resolveAgentID(meta) if agentID != "" { + if !oc.agentsEnabled() { + return nil, "", oc.agentFeaturesDisabledErr() + } store := NewAgentStoreAdapter(oc) agent, err := store.GetAgentByID(ctx, agentID) if err != nil || agent == nil { @@ -1093,7 +1105,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { if err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to load default chat portal by ID") } else if portal != nil { - if !isDefaultChatCandidate(portal) { + if !oc.isDefaultChatCandidate(portal) { deterministicPortalBlocked = portal.PortalKey == defaultPortalKey oc.loggerForContext(ctx).Warn().Stringer("portal", portal.PortalKey).Msg("Ignoring hidden portal stored as default chat") loginMeta.DefaultChatPortalID = "" @@ -1121,7 +1133,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultPortalKey) if err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to load default chat portal by deterministic key") - } else if portal != nil && isDefaultChatCandidate(portal) { + } else if portal != nil && oc.isDefaultChatCandidate(portal) { return oc.ensureExistingChatPortalReady(ctx, loginMeta, portal, "Existing default chat already has MXID", "Default chat missing MXID; creating Matrix room", "Failed to create Matrix room for default chat") } else if portal != nil { deterministicPortalBlocked = true @@ -1135,23 +1147,13 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { return err } - defaultPortal := chooseDefaultChatPortal(portals) + defaultPortal := oc.chooseDefaultChatPortal(portals) if defaultPortal != nil { return oc.ensureExistingChatPortalReady(ctx, loginMeta, defaultPortal, "Existing chat already has MXID", "Existing portal missing MXID; creating Matrix room", "Failed to create Matrix room for existing portal") } - // Create default chat with Beep agent - beeperAgent := agents.GetBeeperAI() - if beeperAgent == nil { - return errors.New("beeper AI agent not found") - } - - // Determine model from agent config or use default - modelID := beeperAgent.Model.Primary - if modelID == "" { - modelID = oc.effectiveModel(nil) - } + modelID := oc.effectiveModel(nil) initOpts := PortalInitOpts{ ModelID: modelID, @@ -1186,23 +1188,26 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { return err } - // Set agent-specific metadata - pm := portalMeta(portal) + if oc.agentsEnabled() { + beeperAgent := agents.GetBeeperAI() + if beeperAgent == nil { + return errors.New("beeper AI agent not found") + } - // Update the OtherUserID to be the agent ghost - agentGhostID := oc.agentUserID(beeperAgent.ID) - portal.OtherUserID = agentGhostID - pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID) + pm := portalMeta(portal) + agentGhostID := oc.agentUserID(beeperAgent.ID) + portal.OtherUserID = agentGhostID + pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID) - if err := portal.Save(ctx); err != nil { - oc.loggerForContext(ctx).Err(err).Msg("Failed to save portal with agent config") - return err - } + if err := portal.Save(ctx); err != nil { + oc.loggerForContext(ctx).Err(err).Msg("Failed to save portal with agent config") + return err + } - // Update chat info members to use agent ghost only - agentName := oc.resolveAgentDisplayName(ctx, beeperAgent) - oc.applyAgentChatInfo(chatInfo, beeperAgent.ID, agentName, modelID) - oc.ensureAgentGhostDisplayName(ctx, beeperAgent.ID, modelID, agentName) + agentName := oc.resolveAgentDisplayName(ctx, beeperAgent) + oc.applyAgentChatInfo(chatInfo, beeperAgent.ID, agentName, modelID) + oc.ensureAgentGhostDisplayName(ctx, beeperAgent.ID, modelID, agentName) + } loginMeta.DefaultChatPortalID = string(portal.PortalKey.ID) if err := oc.UserLogin.Save(ctx); err != nil { @@ -1218,7 +1223,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { } func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta *UserLoginMetadata, portal *bridgev2.Portal, readyMsg string, createMsg string, errMsg string) error { - if !isDefaultChatCandidate(portal) { + if !oc.isDefaultChatCandidate(portal) { return fmt.Errorf("portal %s is hidden and can't be selected as default chat", portal.PortalKey) } if loginMeta != nil { @@ -1241,34 +1246,6 @@ func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta return nil } -func isDefaultChatCandidate(portal *bridgev2.Portal) bool { - return portal != nil && !shouldExcludeModelVisiblePortal(portalMeta(portal)) -} - -func chooseDefaultChatPortal(portals []*bridgev2.Portal) *bridgev2.Portal { - var defaultPortal *bridgev2.Portal - var ( - minIdx int - haveSlug bool - ) - for _, portal := range portals { - if !isDefaultChatCandidate(portal) { - continue - } - pm := portalMeta(portal) - if idx, ok := parseChatSlug(pm.Slug); ok { - if !haveSlug || idx < minIdx { - minIdx = idx - defaultPortal = portal - haveSlug = true - } - } else if defaultPortal == nil && !haveSlug { - defaultPortal = portal - } - } - return defaultPortal -} - func (oc *AIClient) listAllChatPortals(ctx context.Context) ([]*bridgev2.Portal, error) { // Query all portals and filter by receiver (our login ID) // This works because all our portals have Receiver set to our UserLogin.ID diff --git a/bridges/ai/chat_login_redirect_test.go b/bridges/ai/chat_login_redirect_test.go index 9db0d677..53c119d9 100644 --- a/bridges/ai/chat_login_redirect_test.go +++ b/bridges/ai/chat_login_redirect_test.go @@ -7,6 +7,15 @@ import ( "testing" ) +func newDiscoveryTestClient(agentsEnabled bool) *AIClient { + client := newCatalogTestClient() + if !agentsEnabled { + disabled := false + client.connector.Config.Agents = &AgentsConfig{Enabled: &disabled} + } + return client +} + func TestSearchUsersRequiresLogin(t *testing.T) { oc := &AIClient{} _, err := oc.SearchUsers(context.Background(), "gpt") @@ -29,6 +38,40 @@ func TestGetContactListRequiresLogin(t *testing.T) { } } +func TestSearchUsersDisabledKeepsModelsButHidesAgents(t *testing.T) { + oc := newDiscoveryTestClient(false) + + results, err := oc.SearchUsers(context.Background(), "gpt") + if err != nil { + t.Fatalf("SearchUsers returned error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected model search results when agents are disabled") + } + for _, result := range results { + if result != nil && strings.HasPrefix(string(result.UserID), "agent-") { + t.Fatalf("expected agent results to be hidden, got %#v", result) + } + } +} + +func TestGetContactListDisabledKeepsModelsButHidesAgents(t *testing.T) { + oc := newDiscoveryTestClient(false) + + results, err := oc.GetContactList(context.Background()) + if err != nil { + t.Fatalf("GetContactList returned error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected model contacts when agents are disabled") + } + for _, result := range results { + if result != nil && strings.HasPrefix(string(result.UserID), "agent-") { + t.Fatalf("expected agent contacts to be hidden, got %#v", result) + } + } +} + func TestModelRedirectTarget(t *testing.T) { tests := []struct { name string diff --git a/bridges/ai/config_test.go b/bridges/ai/config_test.go index 7c284641..c4063cb0 100644 --- a/bridges/ai/config_test.go +++ b/bridges/ai/config_test.go @@ -77,3 +77,24 @@ func TestInboundConfig_WithDefaults_PartialValues(t *testing.T) { t.Errorf("Expected default DefaultDebounceMs %d, got %d", DefaultDebounceMs, result.DefaultDebounceMs) } } + +func TestConfigAgentsEnabledDefaultsToTrue(t *testing.T) { + if !(new(Config)).agentsEnabled() { + t.Fatal("expected agents to be enabled by default") + } + if (&Config{Agents: &AgentsConfig{}}).agentsEnabled() == false { + t.Fatal("expected missing agents.enabled to default to true") + } +} + +func TestConfigAgentsEnabledCanBeDisabled(t *testing.T) { + disabled := false + cfg := &Config{ + Agents: &AgentsConfig{ + Enabled: &disabled, + }, + } + if cfg.agentsEnabled() { + t.Fatal("expected agents to be disabled") + } +} diff --git a/bridges/ai/connector.go b/bridges/ai/connector.go index 37133757..b1383b47 100644 --- a/bridges/ai/connector.go +++ b/bridges/ai/connector.go @@ -2,7 +2,6 @@ package ai import ( "context" - "fmt" "slices" "strings" "sync" @@ -76,6 +75,9 @@ func (oc *OpenAIConnector) ValidateUserID(id networkid.UserID) bool { if modelID := parseModelFromGhostID(string(id)); strings.TrimSpace(modelID) != "" { return resolveModelIDFromManifest(modelID) != "" } + if !oc.agentsEnabled() { + return false + } if agentID, ok := parseAgentFromGhostID(string(id)); ok && isValidAgentID(strings.TrimSpace(agentID)) { return true } @@ -98,7 +100,7 @@ func (oc *OpenAIConnector) getLoginFlows() []bridgev2.LoginFlow { func (oc *OpenAIConnector) createLogin(ctx context.Context, user *bridgev2.User, flowID string) (bridgev2.LoginProcess, error) { flows := oc.getLoginFlows() if !slices.ContainsFunc(flows, func(f bridgev2.LoginFlow) bool { return f.ID == flowID }) { - return nil, fmt.Errorf("login flow %s is not available", flowID) + return nil, bridgev2.ErrInvalidLoginFlowID } return &OpenAILogin{User: user, Connector: oc, FlowID: flowID}, nil } diff --git a/bridges/ai/default_chat_test.go b/bridges/ai/default_chat_test.go index 80eeab7f..2bc64a86 100644 --- a/bridges/ai/default_chat_test.go +++ b/bridges/ai/default_chat_test.go @@ -9,6 +9,7 @@ import ( ) func TestChooseDefaultChatPortalSkipsHiddenRooms(t *testing.T) { + client := &AIClient{} hidden := &bridgev2.Portal{ Portal: &database.Portal{ PortalKey: networkid.PortalKey{ID: "openai:hidden"}, @@ -27,8 +28,38 @@ func TestChooseDefaultChatPortalSkipsHiddenRooms(t *testing.T) { }, } - selected := chooseDefaultChatPortal([]*bridgev2.Portal{hidden, visible}) + selected := client.chooseDefaultChatPortal([]*bridgev2.Portal{hidden, visible}) if selected != visible { t.Fatalf("expected visible portal to be selected, got %#v", selected) } } + +func TestChooseDefaultChatPortalDisabledSkipsAgentRooms(t *testing.T) { + disabled := false + client := &AIClient{ + connector: &OpenAIConnector{ + Config: Config{ + Agents: &AgentsConfig{Enabled: &disabled}, + }, + }, + } + agentPortal := &bridgev2.Portal{ + Portal: &database.Portal{ + PortalKey: networkid.PortalKey{ID: "openai:agent"}, + OtherUserID: agentUserID("beeper"), + Metadata: agentModeTestMeta("beeper"), + }, + } + modelPortal := &bridgev2.Portal{ + Portal: &database.Portal{ + PortalKey: networkid.PortalKey{ID: "openai:model"}, + OtherUserID: modelUserID("openai/gpt-5"), + Metadata: simpleModeTestMeta("openai/gpt-5"), + }, + } + + selected := client.chooseDefaultChatPortal([]*bridgev2.Portal{agentPortal, modelPortal}) + if selected != modelPortal { + t.Fatalf("expected model portal to be selected when agents are disabled, got %#v", selected) + } +} diff --git a/bridges/ai/handleai.go b/bridges/ai/handleai.go index fb982978..f095ae70 100644 --- a/bridges/ai/handleai.go +++ b/bridges/ai/handleai.go @@ -262,6 +262,9 @@ func (oc *AIClient) scheduleAutoGreeting(ctx context.Context, portal *bridgev2.P if autoGreetingBlockReason(meta) != "" { return } + if oc.agentTargetBlocked(meta) { + return + } if oc.hasPortalMessages(ctx, portal) { return } @@ -298,6 +301,10 @@ func (oc *AIClient) scheduleAutoGreeting(ctx context.Context, portal *bridgev2.P oc.Log().Debug().Stringer("room_id", roomID).Str("reason", reason).Msg("auto-greeting loop exiting: blocked by portal state") return } + if oc.agentTargetBlocked(currentMeta) { + oc.Log().Debug().Stringer("room_id", roomID).Msg("auto-greeting loop exiting: agents disabled") + return + } if oc.hasPortalMessages(bgCtx, current) { oc.Log().Debug().Stringer("room_id", roomID).Msg("auto-greeting loop exiting: portal has messages") return diff --git a/bridges/ai/handleai_test.go b/bridges/ai/handleai_test.go index 6e6217f8..df93b056 100644 --- a/bridges/ai/handleai_test.go +++ b/bridges/ai/handleai_test.go @@ -5,8 +5,12 @@ import ( "strings" "testing" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) func TestDecodeBase64Image(t *testing.T) { @@ -140,3 +144,28 @@ func TestMessageStatusReasonForError_AccessDenied403(t *testing.T) { t.Fatalf("expected no-permission reason, got %s", got) } } + +func TestDispatchInternalMessageRejectsDisabledAgentRoom(t *testing.T) { + disabled := false + oc := &AIClient{ + connector: &OpenAIConnector{ + Config: Config{ + Agents: &AgentsConfig{Enabled: &disabled}, + }, + }, + } + portal := &bridgev2.Portal{ + Portal: &database.Portal{ + MXID: id.RoomID("!room:example.com"), + PortalKey: networkid.PortalKey{ID: "openai:test"}, + }, + } + + _, _, err := oc.dispatchInternalMessage(t.Context(), portal, agentModeTestMeta("beeper"), "hello", "test", false) + if err == nil { + t.Fatal("expected dispatchInternalMessage to fail") + } + if !strings.Contains(strings.ToLower(err.Error()), "agents are disabled") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index fe15b118..b7110f07 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -35,6 +35,10 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri return nil, errors.New("portal is nil") } meta := portalMeta(portal) + if err := oc.ensureAgentTargetAllowed(meta); err != nil { + oc.sendSystemNotice(ctx, portal, err.Error()) + return nil, agentremote.UnsupportedMessageStatus(err) + } if msg.Event == nil { return nil, errors.New("missing message event") } diff --git a/bridges/ai/heartbeat_config.go b/bridges/ai/heartbeat_config.go index 26058066..43163d37 100644 --- a/bridges/ai/heartbeat_config.go +++ b/bridges/ai/heartbeat_config.go @@ -7,7 +7,7 @@ import ( ) func hasExplicitHeartbeatAgents(cfg *Config) bool { - if cfg == nil || cfg.Agents == nil { + if cfg == nil || !cfg.agentsEnabled() || cfg.Agents == nil { return false } for _, entry := range cfg.Agents.List { @@ -79,6 +79,9 @@ func isHeartbeatEnabledForAgent(cfg *Config, agentID string) bool { if cfg == nil { return resolved == defaultAgent } + if !cfg.agentsEnabled() { + return false + } if cfg.Agents == nil { return resolved == defaultAgent } diff --git a/bridges/ai/heartbeat_config_test.go b/bridges/ai/heartbeat_config_test.go index 483e808b..00ffe7c6 100644 --- a/bridges/ai/heartbeat_config_test.go +++ b/bridges/ai/heartbeat_config_test.go @@ -43,3 +43,18 @@ func TestIsHeartbeatEnabledForAgent_ExplicitAgentsMode(t *testing.T) { t.Fatalf("expected agent without heartbeat block to be disabled in explicit heartbeat mode") } } + +func TestIsHeartbeatEnabledForAgent_DisabledByConfig(t *testing.T) { + disabled := false + cfg := &Config{ + Agents: &AgentsConfig{ + Enabled: &disabled, + }, + } + if isHeartbeatEnabledForAgent(cfg, normalizeAgentID(agents.DefaultAgentID)) { + t.Fatal("expected heartbeat to be disabled when agents are disabled") + } + if got := resolveHeartbeatAgents(cfg); len(got) != 0 { + t.Fatalf("expected no heartbeat agents when disabled, got %#v", got) + } +} diff --git a/bridges/ai/heartbeat_execute.go b/bridges/ai/heartbeat_execute.go index d1297bec..744ccd06 100644 --- a/bridges/ai/heartbeat_execute.go +++ b/bridges/ai/heartbeat_execute.go @@ -30,7 +30,7 @@ type heartbeatAgent struct { func resolveHeartbeatAgents(cfg *Config) []heartbeatAgent { var list []heartbeatAgent - if cfg == nil { + if cfg == nil || !cfg.agentsEnabled() { return list } if hasExplicitHeartbeatAgents(cfg) { diff --git a/bridges/ai/integrations_config.go b/bridges/ai/integrations_config.go index ad7298b7..b229b59c 100644 --- a/bridges/ai/integrations_config.go +++ b/bridges/ai/integrations_config.go @@ -107,6 +107,7 @@ func (c *ToolApprovalsRuntimeConfig) WithDefaults() *ToolApprovalsRuntimeConfig // AgentsConfig configures agent defaults. type AgentsConfig struct { + Enabled *bool `yaml:"enabled"` Defaults *AgentDefaultsConfig `yaml:"defaults"` List []AgentEntryConfig `yaml:"list"` } @@ -566,6 +567,7 @@ func upgradeConfig(helper configupgrade.Helper) { helper.Copy(configupgrade.Str, "session", "main_key") // Agents heartbeat configuration + helper.Copy(configupgrade.Bool, "agents", "enabled") helper.Copy(configupgrade.Int, "agents", "defaults", "timeout_seconds") helper.Copy(configupgrade.Str, "agents", "defaults", "user_timezone") helper.Copy(configupgrade.Str, "agents", "defaults", "envelope_timezone") diff --git a/bridges/ai/integrations_example-config.yaml b/bridges/ai/integrations_example-config.yaml index cab66e41..554da5a0 100644 --- a/bridges/ai/integrations_example-config.yaml +++ b/bridges/ai/integrations_example-config.yaml @@ -222,6 +222,7 @@ tools: # Agent defaults. # agents: + # enabled: true # defaults: # subagents: # model: "anthropic/claude-sonnet-4.5" diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index 21fddb9e..1f442518 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -31,6 +31,9 @@ func (oc *AIClient) dispatchInternalMessage( return "", false, errors.New("missing portal metadata") } } + if err := oc.ensureAgentTargetAllowed(meta); err != nil { + return "", false, err + } trimmed := strings.TrimSpace(body) if trimmed == "" { return "", false, errors.New("message body is required") diff --git a/bridges/ai/provisioning.go b/bridges/ai/provisioning.go index 4c7c1cc7..3fafb50c 100644 --- a/bridges/ai/provisioning.go +++ b/bridges/ai/provisioning.go @@ -222,6 +222,8 @@ type agentUpsertRequest struct { func writeAgentError(w http.ResponseWriter, err error) { switch { + case errors.Is(err, errAgentsDisabled): + mautrix.MForbidden.WithMessage("%v.", err).Write(w) case errors.Is(err, agents.ErrAgentNotFound): mautrix.MNotFound.WithMessage("Agent not found.").Write(w) case errors.Is(err, agents.ErrAgentIsPreset): @@ -325,7 +327,10 @@ func agentResponse(agent *agents.AgentDefinition) *AgentDefinitionContent { return ToAgentDefinitionContent(agent) } -func listAgentsForResponse(ctx context.Context, store *AgentStoreAdapter) ([]*AgentDefinitionContent, error) { +func listAgentsForResponse(ctx context.Context, client *AIClient, store *AgentStoreAdapter) ([]*AgentDefinitionContent, error) { + if client != nil && !client.agentsEnabled() { + return []*AgentDefinitionContent{}, nil + } loaded, err := store.LoadAgents(ctx) if err != nil { return nil, err @@ -349,7 +354,7 @@ func (api *ProvisioningAPI) handleListAgents(w http.ResponseWriter, r *http.Requ if client == nil { return } - items, err := listAgentsForResponse(r.Context(), NewAgentStoreAdapter(client)) + items, err := listAgentsForResponse(r.Context(), client, NewAgentStoreAdapter(client)) if err != nil { mautrix.MUnknown.WithMessage("Couldn't list agents: %v.", err).Write(w) return @@ -362,6 +367,10 @@ func (api *ProvisioningAPI) handleGetAgent(w http.ResponseWriter, r *http.Reques if client == nil { return } + if !client.agentsEnabled() { + writeAgentError(w, client.agentFeaturesDisabledErr()) + return + } agentID := strings.TrimSpace(r.PathValue("agent_id")) agent, err := NewAgentStoreAdapter(client).GetAgentByID(r.Context(), agentID) if err != nil { @@ -376,6 +385,10 @@ func (api *ProvisioningAPI) handleCreateAgent(w http.ResponseWriter, r *http.Req if client == nil { return } + if !client.agentsEnabled() { + writeAgentError(w, client.agentFeaturesDisabledErr()) + return + } var req agentUpsertRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) @@ -404,6 +417,10 @@ func (api *ProvisioningAPI) handleUpdateAgent(w http.ResponseWriter, r *http.Req if client == nil { return } + if !client.agentsEnabled() { + writeAgentError(w, client.agentFeaturesDisabledErr()) + return + } var req agentUpsertRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) @@ -438,6 +455,10 @@ func (api *ProvisioningAPI) handleDeleteAgent(w http.ResponseWriter, r *http.Req if client == nil { return } + if !client.agentsEnabled() { + writeAgentError(w, client.agentFeaturesDisabledErr()) + return + } agentID := strings.TrimSpace(r.PathValue("agent_id")) if err := NewAgentStoreAdapter(client).DeleteAgent(r.Context(), agentID); err != nil { writeAgentError(w, err) diff --git a/bridges/ai/provisioning_agents_test.go b/bridges/ai/provisioning_agents_test.go new file mode 100644 index 00000000..f0df7715 --- /dev/null +++ b/bridges/ai/provisioning_agents_test.go @@ -0,0 +1,41 @@ +package ai + +import ( + "context" + "errors" + "net/http/httptest" + "strings" + "testing" +) + +func TestListAgentsForResponseDisabledReturnsEmpty(t *testing.T) { + disabled := false + client := newCatalogTestClient() + client.connector.Config.Agents = &AgentsConfig{Enabled: &disabled} + + items, err := listAgentsForResponse(context.Background(), client, NewAgentStoreAdapter(client)) + if err != nil { + t.Fatalf("listAgentsForResponse returned error: %v", err) + } + if len(items) != 0 { + t.Fatalf("expected no agents when disabled, got %#v", items) + } +} + +func TestWriteAgentErrorDisabledReturnsForbidden(t *testing.T) { + rec := httptest.NewRecorder() + writeAgentError(rec, errAgentsDisabled) + + if rec.Code != 403 { + t.Fatalf("expected 403, got %d", rec.Code) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "agents are disabled") { + t.Fatalf("unexpected response body: %s", rec.Body.String()) + } +} + +func TestWriteAgentErrorDisabledMatchesSentinel(t *testing.T) { + if !errors.Is(errAgentsDisabled, errAgentsDisabled) { + t.Fatal("expected sentinel error to match itself") + } +} diff --git a/bridges/ai/sdk_agent_catalog.go b/bridges/ai/sdk_agent_catalog.go index 1052e8c0..3039333c 100644 --- a/bridges/ai/sdk_agent_catalog.go +++ b/bridges/ai/sdk_agent_catalog.go @@ -22,6 +22,9 @@ func (c aiAgentCatalog) DefaultAgent(ctx context.Context, login *bridgev2.UserLo if client == nil { return nil, nil } + if !client.agentsEnabled() { + return nil, nil + } agent, err := NewAgentStoreAdapter(client).GetAgentByID(ctx, agents.DefaultAgentID) if err != nil || agent == nil { return nil, err @@ -34,6 +37,9 @@ func (c aiAgentCatalog) ListAgents(ctx context.Context, login *bridgev2.UserLogi if client == nil { return nil, nil } + if !client.agentsEnabled() { + return nil, nil + } agentsMap, err := NewAgentStoreAdapter(client).LoadAgents(ctx) if err != nil { return nil, err @@ -60,6 +66,9 @@ func (c aiAgentCatalog) ResolveAgent(ctx context.Context, login *bridgev2.UserLo if client == nil { return nil, nil } + if !client.agentsEnabled() { + return nil, nil + } agentID := normalizedCatalogAgentIdentifier(identifier) if agentID == "" { return nil, nil diff --git a/bridges/ai/sdk_agent_catalog_test.go b/bridges/ai/sdk_agent_catalog_test.go index 8eabd1be..3f313521 100644 --- a/bridges/ai/sdk_agent_catalog_test.go +++ b/bridges/ai/sdk_agent_catalog_test.go @@ -4,6 +4,7 @@ import ( "context" "slices" "testing" + "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" @@ -13,11 +14,18 @@ import ( ) func newCatalogTestClient() *AIClient { - return &AIClient{ + client := &AIClient{ UserLogin: &bridgev2.UserLogin{ UserLogin: &database.UserLogin{ ID: "login-1", Metadata: &UserLoginMetadata{ + ModelCache: &ModelCache{ + Models: []ModelInfo{ + {ID: "openai/gpt-5", Name: "GPT-5"}, + }, + LastRefresh: time.Now().Unix(), + CacheDuration: 3600, + }, CustomAgents: map[string]*AgentDefinitionContent{ "custom-agent": { ID: "custom-agent", @@ -32,6 +40,8 @@ func newCatalogTestClient() *AIClient { }, connector: &OpenAIConnector{}, } + client.SetLoggedIn(true) + return client } func TestAIAgentCatalogDefaultAgent(t *testing.T) { @@ -94,3 +104,34 @@ func TestAIAgentCatalogListsAndResolvesCustomAgents(t *testing.T) { t.Fatalf("expected avatar URL to be preserved, got %q", resolved.AvatarURL) } } + +func TestAIAgentCatalogDisabledReturnsNoAgents(t *testing.T) { + disabled := false + client := newCatalogTestClient() + client.connector.Config.Agents = &AgentsConfig{Enabled: &disabled} + catalog := client.sdkAgentCatalog() + + agent, err := catalog.DefaultAgent(context.Background(), client.UserLogin) + if err != nil { + t.Fatalf("DefaultAgent returned error: %v", err) + } + if agent != nil { + t.Fatalf("expected no default agent when disabled, got %#v", agent) + } + + agentsList, err := catalog.ListAgents(context.Background(), client.UserLogin) + if err != nil { + t.Fatalf("ListAgents returned error: %v", err) + } + if len(agentsList) != 0 { + t.Fatalf("expected no agents when disabled, got %#v", agentsList) + } + + resolved, err := catalog.ResolveAgent(context.Background(), client.UserLogin, "custom-agent") + if err != nil { + t.Fatalf("ResolveAgent returned error: %v", err) + } + if resolved != nil { + t.Fatalf("expected no resolved agent when disabled, got %#v", resolved) + } +} diff --git a/bridges/ai/sessions_tools.go b/bridges/ai/sessions_tools.go index bdc7f92a..d8027bdc 100644 --- a/bridges/ai/sessions_tools.go +++ b/bridges/ai/sessions_tools.go @@ -72,7 +72,7 @@ func (oc *AIClient) executeSessionsList(ctx context.Context, portal *bridgev2.Po continue } meta := portalMeta(candidate) - if shouldExcludeModelVisiblePortal(meta) { + if oc.shouldExcludeVisiblePortal(meta) { continue } kind := resolveSessionKind(currentRoomID, candidate, meta) @@ -541,7 +541,7 @@ func (oc *AIClient) resolveSessionPortalByLabel(ctx context.Context, label strin continue } meta := portalMeta(candidate) - if shouldExcludeModelVisiblePortal(meta) { + if oc.shouldExcludeVisiblePortal(meta) { continue } if filterAgent != "" { diff --git a/cmd/agentremote/bridges.go b/cmd/agentremote/bridges.go index bde0c8e3..01b0d4a0 100644 --- a/cmd/agentremote/bridges.go +++ b/cmd/agentremote/bridges.go @@ -13,32 +13,69 @@ import ( type bridgeDef struct { bridgeentry.Definition - NewFunc func() bridgev2.NetworkConnector + NewFunc func() bridgev2.NetworkConnector + RuntimeBridgeType string + RemoteBridgeType string + ConfigOverrides map[string]any } var bridgeRegistry = map[string]bridgeDef{ "ai": { - Definition: bridgeentry.AI, - NewFunc: func() bridgev2.NetworkConnector { return aibridge.NewAIConnector() }, + Definition: bridgeentry.AI, + NewFunc: func() bridgev2.NetworkConnector { return aibridge.NewAIConnector() }, + RuntimeBridgeType: "ai", + RemoteBridgeType: "ai", + ConfigOverrides: map[string]any{ + "network.agents.enabled": false, + }, + }, + "agent": { + Definition: bridgeentry.Agent, + NewFunc: func() bridgev2.NetworkConnector { return aibridge.NewAIConnector() }, + RuntimeBridgeType: "ai", + RemoteBridgeType: "ai", + ConfigOverrides: map[string]any{ + "network.agents.enabled": true, + }, }, "codex": { - Definition: bridgeentry.Codex, - NewFunc: func() bridgev2.NetworkConnector { return codex.NewConnector() }, + Definition: bridgeentry.Codex, + NewFunc: func() bridgev2.NetworkConnector { return codex.NewConnector() }, + RemoteBridgeType: "codex", }, "opencode": { - Definition: bridgeentry.OpenCode, - NewFunc: func() bridgev2.NetworkConnector { return opencode.NewConnector() }, + Definition: bridgeentry.OpenCode, + NewFunc: func() bridgev2.NetworkConnector { return opencode.NewConnector() }, + RemoteBridgeType: "opencode", }, "openclaw": { - Definition: bridgeentry.OpenClaw, - NewFunc: func() bridgev2.NetworkConnector { return openclaw.NewConnector() }, + Definition: bridgeentry.OpenClaw, + NewFunc: func() bridgev2.NetworkConnector { return openclaw.NewConnector() }, + RemoteBridgeType: "openclaw", }, "dummybridge": { - Definition: bridgeentry.DummyBridge, - NewFunc: func() bridgev2.NetworkConnector { return dummybridge.NewConnector() }, + Definition: bridgeentry.DummyBridge, + NewFunc: func() bridgev2.NetworkConnector { return dummybridge.NewConnector() }, + RemoteBridgeType: "dummybridge", }, } +func remoteBridgeType(localBridgeType string) string { + def, ok := bridgeRegistry[localBridgeType] + if !ok || def.RemoteBridgeType == "" { + return localBridgeType + } + return def.RemoteBridgeType +} + +func runtimeBridgeType(localBridgeType string) string { + def, ok := bridgeRegistry[localBridgeType] + if !ok || def.RuntimeBridgeType == "" { + return localBridgeType + } + return def.RuntimeBridgeType +} + func beeperBridgeName(bridgeType, name string) string { if name == "" { return "sh-" + bridgeType diff --git a/cmd/agentremote/commands.go b/cmd/agentremote/commands.go index 8b6af4c1..b85f726a 100644 --- a/cmd/agentremote/commands.go +++ b/cmd/agentremote/commands.go @@ -101,6 +101,7 @@ func initCommands() { }, Examples: []string{ "agentremote start ai", + "agentremote start agent", "agentremote start codex --name test", "agentremote start opencode --profile work", "agentremote start ai --wait", @@ -122,6 +123,7 @@ func initCommands() { }, Examples: []string{ "agentremote up ai", + "agentremote up agent", "agentremote up codex --name test", }, Run: cmdUp, @@ -138,6 +140,7 @@ func initCommands() { }, Examples: []string{ "agentremote run ai", + "agentremote run agent", "agentremote run codex --name dev", }, Run: cmdRun, @@ -154,6 +157,7 @@ func initCommands() { }, Examples: []string{ "agentremote init ai", + "agentremote init agent", "agentremote init openclaw --name dev", }, Run: cmdInit, @@ -240,6 +244,7 @@ func initCommands() { }, Examples: []string{ "agentremote register ai", + "agentremote register agent", "agentremote register codex --name dev --json", }, Run: cmdRegister, diff --git a/cmd/agentremote/main.go b/cmd/agentremote/main.go index 4406940f..5a35fd48 100644 --- a/cmd/agentremote/main.go +++ b/cmd/agentremote/main.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "io" + "maps" "os" "os/exec" "path/filepath" @@ -437,7 +438,7 @@ func cmdRun(args []string) error { if err != nil { return fmt.Errorf("failed to find own executable: %w", err) } - argv := []string{exe, "__bridge", bridgeType, "-c", meta.ConfigPath} + argv := []string{exe, "__bridge", runtimeBridgeType(bridgeType), "-c", meta.ConfigPath} fmt.Printf("running %s in foreground\n", instName) cliutil.PrintRuntimePaths(meta) if err = os.Chdir(filepath.Dir(meta.ConfigPath)); err != nil { @@ -1120,6 +1121,7 @@ func ensureInitialized(instName, bridgeType, beeperName string, sp *instancePath "beeper.com": "admin", }, } + maps.Copy(overrides, def.ConfigOverrides) if err = bridgeutil.ApplyConfigOverrides(meta.ConfigPath, overrides); err != nil { return nil, err } @@ -1151,7 +1153,7 @@ func generateExampleConfig(meta *metadata) error { if err != nil { return fmt.Errorf("failed to find own executable: %w", err) } - cmd := exec.Command(exe, "__bridge", meta.BridgeType, "-c", meta.ConfigPath, "-e") + cmd := exec.Command(exe, "__bridge", runtimeBridgeType(meta.BridgeType), "-c", meta.ConfigPath, "-e") cmd.Dir = filepath.Dir(meta.ConfigPath) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -1185,7 +1187,7 @@ func ensureRegistration(profile, envOverride string, meta *metadata, bridgeType ConfigPath: meta.ConfigPath, RegistrationPath: meta.RegistrationPath, BeeperBridgeName: meta.BeeperBridgeName, - BridgeType: bridgeType, + BridgeType: remoteBridgeType(bridgeType), DBName: bridgeRegistry[bridgeType].DBName, }) } @@ -1210,5 +1212,5 @@ func startBridgeProcess(meta *metadata, bridgeType string) error { if err != nil { return fmt.Errorf("failed to find own executable: %w", err) } - return bridgeutil.StartBridgeFromConfig(exe, []string{"__bridge", bridgeType, "-c", meta.ConfigPath}, meta.ConfigPath, meta.LogPath, meta.PIDPath) + return bridgeutil.StartBridgeFromConfig(exe, []string{"__bridge", runtimeBridgeType(bridgeType), "-c", meta.ConfigPath}, meta.ConfigPath, meta.LogPath, meta.PIDPath) } diff --git a/cmd/agentremote/run_bridge.go b/cmd/agentremote/run_bridge.go index 80b3f635..7280b966 100644 --- a/cmd/agentremote/run_bridge.go +++ b/cmd/agentremote/run_bridge.go @@ -23,7 +23,7 @@ func cmdInternalBridge(args []string) error { // Replace os.Args so mxmain sees: [bridge-flags...] // e.g. agentremote __bridge ai -c config.yaml → ai -c config.yaml os.Args = append([]string{def.Name}, args[1:]...) - if bridgeType == "ai" { + if bridgeType == "ai" || bridgeType == "agent" { bridgev2.PortalEventBuffer = 0 } diff --git a/cmd/internal/bridgeentry/bridgeentry.go b/cmd/internal/bridgeentry/bridgeentry.go index b5114af5..3a84af56 100644 --- a/cmd/internal/bridgeentry/bridgeentry.go +++ b/cmd/internal/bridgeentry/bridgeentry.go @@ -24,6 +24,12 @@ var ( Port: 29345, DBName: "ai.db", } + Agent = Definition{ + Name: "agent", + Description: "Agent-enabled mode for the AI bridge built on mautrix-go bridgev2.", + Port: 29350, + DBName: "agent.db", + } Codex = Definition{ Name: "codex", Description: "A Matrix↔Codex bridge built on mautrix-go bridgev2.", diff --git a/config.example.yaml b/config.example.yaml index e5c6a761..5c4a4028 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -199,6 +199,7 @@ default_system_prompt: | # Agent defaults (OpenClaw-style). # agents: + # enabled: true # defaults: # subagents: # model: "anthropic/claude-sonnet-4.5"