Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions bridges.manifest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@ instances:
bridge.permissions:
"*": relay
beeper.com: admin
network.agents.enabled: false

agent:
bridge_type: ai
mode: local-repo
repo_path: .
build_cmd: ./build.sh
binary_path: ./ai
beeper_bridge_name: sh-agent
config_overrides:
appservice.address: websocket
appservice.hostname: 127.0.0.1
appservice.port: 29350
database.type: sqlite3-fk-wal
database.uri: file:agent.db?_txlock=immediate
bridge.permissions:
"*": relay
beeper.com: admin
network.agents.enabled: true

codex:
bridge_type: codex
Expand Down
7 changes: 5 additions & 2 deletions bridges/ai/agent_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ func (oc *AIClient) recordAgentActivity(ctx context.Context, portal *bridgev2.Po
if heartbeatRunFromContext(ctx) != nil {
return
}
if oc.agentTargetBlocked(meta) {
return
}
agentID := normalizeAgentID(resolveAgentID(meta))
if agentID == "" {
return
Expand Down Expand Up @@ -93,12 +96,12 @@ func (oc *AIClient) defaultChatPortal() *bridgev2.Portal {
Receiver: oc.UserLogin.ID,
}
if portal, err := oc.UserLogin.Bridge.GetPortalByKey(ctx, portalKey); err == nil && portal != nil {
if isDefaultChatCandidate(portal) {
if oc.isDefaultChatCandidate(portal) {
return portal
}
}
}
if portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultChatPortalKey(oc.UserLogin.ID)); err == nil && portal != nil && isDefaultChatCandidate(portal) {
if portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultChatPortalKey(oc.UserLogin.ID)); err == nil && portal != nil && oc.isDefaultChatCandidate(portal) {
return portal
}
return nil
Expand Down
80 changes: 80 additions & 0 deletions bridges/ai/agent_mode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package ai

import (
"errors"

"maunium.net/go/mautrix/bridgev2"
)

var errAgentsDisabled = errors.New("agents are disabled by bridge config")

func (c *Config) agentsEnabled() bool {
if c == nil || c.Agents == nil || c.Agents.Enabled == nil {
return true
}
return *c.Agents.Enabled
}

func (oc *OpenAIConnector) agentsEnabled() bool {
if oc == nil {
return true
}
return oc.Config.agentsEnabled()
}

func (oc *AIClient) agentsEnabled() bool {
if oc == nil || oc.connector == nil {
return true
}
return oc.connector.agentsEnabled()
}

func (oc *AIClient) agentFeaturesDisabledErr() error {
return errAgentsDisabled
}

func (oc *AIClient) agentTargetBlocked(meta *PortalMetadata) bool {
return oc != nil && !oc.agentsEnabled() && resolveAgentID(meta) != ""
}

func (oc *AIClient) ensureAgentTargetAllowed(meta *PortalMetadata) error {
if oc.agentTargetBlocked(meta) {
return oc.agentFeaturesDisabledErr()
}
return nil
}

func (oc *AIClient) shouldExcludeVisiblePortal(meta *PortalMetadata) bool {
if shouldExcludeModelVisiblePortal(meta) {
return true
}
return oc.agentTargetBlocked(meta)
}

func (oc *AIClient) isDefaultChatCandidate(portal *bridgev2.Portal) bool {
return portal != nil && !oc.shouldExcludeVisiblePortal(portalMeta(portal))
}

func (oc *AIClient) chooseDefaultChatPortal(portals []*bridgev2.Portal) *bridgev2.Portal {
var defaultPortal *bridgev2.Portal
var (
minIdx int
haveSlug bool
)
for _, portal := range portals {
if !oc.isDefaultChatCandidate(portal) {
continue
}
pm := portalMeta(portal)
if idx, ok := parseChatSlug(pm.Slug); ok {
if !haveSlug || idx < minIdx {
minIdx = idx
defaultPortal = portal
haveSlug = true
}
} else if defaultPortal == nil && !haveSlug {
defaultPortal = portal
}
}
return defaultPortal
}
91 changes: 34 additions & 57 deletions bridges/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho
return resp.Chat, nil
}
if agentID, ok := parseAgentFromGhostID(ghostID); ok {
if !oc.agentsEnabled() {
return nil, bridgev2.WrapRespErr(oc.agentFeaturesDisabledErr(), mautrix.MForbidden)
}
store := NewAgentStoreAdapter(oc)
agent, err := store.GetAgentByID(ctx, agentID)
if err != nil || agent == nil {
Expand All @@ -425,6 +428,9 @@ func (oc *AIClient) CreateChatWithGhost(ctx context.Context, ghost *bridgev2.Gho

// resolveAgentIdentifier resolves an agent to a ghost and optionally creates a chat.
func (oc *AIClient) resolveAgentIdentifier(ctx context.Context, agent *agents.AgentDefinition, modelID string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) {
if !oc.agentsEnabled() {
return nil, oc.agentFeaturesDisabledErr()
}
explicitModel := modelID != ""
if modelID == "" {
modelID = oc.agentDefaultModel(agent)
Expand Down Expand Up @@ -730,6 +736,9 @@ func (oc *AIClient) resolveNewChatTarget(
if cmd != "agent" {
return nil, "", errors.New(usage)
}
if !oc.agentsEnabled() {
return nil, "", oc.agentFeaturesDisabledErr()
}
targetID := args[1]
if targetID == "" || len(args) > 2 {
return nil, "", errors.New(usage)
Expand All @@ -753,6 +762,9 @@ func (oc *AIClient) resolveNewChatTarget(
}
agentID := resolveAgentID(meta)
if agentID != "" {
if !oc.agentsEnabled() {
return nil, "", oc.agentFeaturesDisabledErr()
}
store := NewAgentStoreAdapter(oc)
agent, err := store.GetAgentByID(ctx, agentID)
if err != nil || agent == nil {
Expand Down Expand Up @@ -1093,7 +1105,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error {
if err != nil {
oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to load default chat portal by ID")
} else if portal != nil {
if !isDefaultChatCandidate(portal) {
if !oc.isDefaultChatCandidate(portal) {
deterministicPortalBlocked = portal.PortalKey == defaultPortalKey
oc.loggerForContext(ctx).Warn().Stringer("portal", portal.PortalKey).Msg("Ignoring hidden portal stored as default chat")
loginMeta.DefaultChatPortalID = ""
Expand Down Expand Up @@ -1121,7 +1133,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error {
portal, err := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultPortalKey)
if err != nil {
oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to load default chat portal by deterministic key")
} else if portal != nil && isDefaultChatCandidate(portal) {
} else if portal != nil && oc.isDefaultChatCandidate(portal) {
return oc.ensureExistingChatPortalReady(ctx, loginMeta, portal, "Existing default chat already has MXID", "Default chat missing MXID; creating Matrix room", "Failed to create Matrix room for default chat")
} else if portal != nil {
deterministicPortalBlocked = true
Expand All @@ -1135,23 +1147,13 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error {
return err
}

defaultPortal := chooseDefaultChatPortal(portals)
defaultPortal := oc.chooseDefaultChatPortal(portals)

if defaultPortal != nil {
return oc.ensureExistingChatPortalReady(ctx, loginMeta, defaultPortal, "Existing chat already has MXID", "Existing portal missing MXID; creating Matrix room", "Failed to create Matrix room for existing portal")
}

// Create default chat with Beep agent
beeperAgent := agents.GetBeeperAI()
if beeperAgent == nil {
return errors.New("beeper AI agent not found")
}

// Determine model from agent config or use default
modelID := beeperAgent.Model.Primary
if modelID == "" {
modelID = oc.effectiveModel(nil)
}
modelID := oc.effectiveModel(nil)

initOpts := PortalInitOpts{
ModelID: modelID,
Expand Down Expand Up @@ -1186,23 +1188,26 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error {
return err
}

// Set agent-specific metadata
pm := portalMeta(portal)
if oc.agentsEnabled() {
beeperAgent := agents.GetBeeperAI()
if beeperAgent == nil {
return errors.New("beeper AI agent not found")
}

// Update the OtherUserID to be the agent ghost
agentGhostID := oc.agentUserID(beeperAgent.ID)
portal.OtherUserID = agentGhostID
pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID)
pm := portalMeta(portal)
agentGhostID := oc.agentUserID(beeperAgent.ID)
portal.OtherUserID = agentGhostID
pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID)

if err := portal.Save(ctx); err != nil {
oc.loggerForContext(ctx).Err(err).Msg("Failed to save portal with agent config")
return err
}
if err := portal.Save(ctx); err != nil {
oc.loggerForContext(ctx).Err(err).Msg("Failed to save portal with agent config")
return err
}

// Update chat info members to use agent ghost only
agentName := oc.resolveAgentDisplayName(ctx, beeperAgent)
oc.applyAgentChatInfo(chatInfo, beeperAgent.ID, agentName, modelID)
oc.ensureAgentGhostDisplayName(ctx, beeperAgent.ID, modelID, agentName)
agentName := oc.resolveAgentDisplayName(ctx, beeperAgent)
oc.applyAgentChatInfo(chatInfo, beeperAgent.ID, agentName, modelID)
oc.ensureAgentGhostDisplayName(ctx, beeperAgent.ID, modelID, agentName)
}

loginMeta.DefaultChatPortalID = string(portal.PortalKey.ID)
if err := oc.UserLogin.Save(ctx); err != nil {
Expand All @@ -1218,7 +1223,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error {
}

func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta *UserLoginMetadata, portal *bridgev2.Portal, readyMsg string, createMsg string, errMsg string) error {
if !isDefaultChatCandidate(portal) {
if !oc.isDefaultChatCandidate(portal) {
return fmt.Errorf("portal %s is hidden and can't be selected as default chat", portal.PortalKey)
}
if loginMeta != nil {
Expand All @@ -1241,34 +1246,6 @@ func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta
return nil
}

func isDefaultChatCandidate(portal *bridgev2.Portal) bool {
return portal != nil && !shouldExcludeModelVisiblePortal(portalMeta(portal))
}

func chooseDefaultChatPortal(portals []*bridgev2.Portal) *bridgev2.Portal {
var defaultPortal *bridgev2.Portal
var (
minIdx int
haveSlug bool
)
for _, portal := range portals {
if !isDefaultChatCandidate(portal) {
continue
}
pm := portalMeta(portal)
if idx, ok := parseChatSlug(pm.Slug); ok {
if !haveSlug || idx < minIdx {
minIdx = idx
defaultPortal = portal
haveSlug = true
}
} else if defaultPortal == nil && !haveSlug {
defaultPortal = portal
}
}
return defaultPortal
}

func (oc *AIClient) listAllChatPortals(ctx context.Context) ([]*bridgev2.Portal, error) {
// Query all portals and filter by receiver (our login ID)
// This works because all our portals have Receiver set to our UserLogin.ID
Expand Down
43 changes: 43 additions & 0 deletions bridges/ai/chat_login_redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ import (
"testing"
)

func newDiscoveryTestClient(agentsEnabled bool) *AIClient {
client := newCatalogTestClient()
if !agentsEnabled {
disabled := false
client.connector.Config.Agents = &AgentsConfig{Enabled: &disabled}
}
return client
}

func TestSearchUsersRequiresLogin(t *testing.T) {
oc := &AIClient{}
_, err := oc.SearchUsers(context.Background(), "gpt")
Expand All @@ -29,6 +38,40 @@ func TestGetContactListRequiresLogin(t *testing.T) {
}
}

func TestSearchUsersDisabledKeepsModelsButHidesAgents(t *testing.T) {
oc := newDiscoveryTestClient(false)

results, err := oc.SearchUsers(context.Background(), "gpt")
if err != nil {
t.Fatalf("SearchUsers returned error: %v", err)
}
if len(results) == 0 {
t.Fatal("expected model search results when agents are disabled")
}
for _, result := range results {
if result != nil && strings.HasPrefix(string(result.UserID), "agent-") {
t.Fatalf("expected agent results to be hidden, got %#v", result)
}
}
}

func TestGetContactListDisabledKeepsModelsButHidesAgents(t *testing.T) {
oc := newDiscoveryTestClient(false)

results, err := oc.GetContactList(context.Background())
if err != nil {
t.Fatalf("GetContactList returned error: %v", err)
}
if len(results) == 0 {
t.Fatal("expected model contacts when agents are disabled")
}
for _, result := range results {
if result != nil && strings.HasPrefix(string(result.UserID), "agent-") {
t.Fatalf("expected agent contacts to be hidden, got %#v", result)
}
}
}

func TestModelRedirectTarget(t *testing.T) {
tests := []struct {
name string
Expand Down
21 changes: 21 additions & 0 deletions bridges/ai/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,24 @@ func TestInboundConfig_WithDefaults_PartialValues(t *testing.T) {
t.Errorf("Expected default DefaultDebounceMs %d, got %d", DefaultDebounceMs, result.DefaultDebounceMs)
}
}

func TestConfigAgentsEnabledDefaultsToTrue(t *testing.T) {
if !(new(Config)).agentsEnabled() {
t.Fatal("expected agents to be enabled by default")
}
if (&Config{Agents: &AgentsConfig{}}).agentsEnabled() == false {
t.Fatal("expected missing agents.enabled to default to true")
}
}

func TestConfigAgentsEnabledCanBeDisabled(t *testing.T) {
disabled := false
cfg := &Config{
Agents: &AgentsConfig{
Enabled: &disabled,
},
}
if cfg.agentsEnabled() {
t.Fatal("expected agents to be disabled")
}
}
Loading
Loading