Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/dedalus_labs/lib/mcp/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dedalus_labs.types.shared_params.mcp_servers import MCPServerItem
from dedalus_labs.types.shared_params.mcp_server_spec import MCPServerSpec

from .wire import serialize_mcp_servers
from .wire import serialize_mcp_servers, slug_to_connection_name
from ..crypto import encrypt_credentials, fetch_encryption_key, fetch_encryption_key_sync
from .protocols import CredentialProtocol

Expand Down Expand Up @@ -161,35 +161,46 @@ def _encrypt_credentials(
return EncryptedCredentials(**encrypted)


def _credentials_for_server(
name: str,
all_creds: Dict[str, str],
) -> Optional[Dict[str, str]]:
"""Return the subset of *all_creds* that belongs to *name*, or None."""
conn = slug_to_connection_name(name)
blob = all_creds.get(conn)
return {conn: blob} if blob is not None else None


def _embed_credentials(
servers: List[MCPServerItem],
encrypted: EncryptedCredentials,
) -> List[MCPServerSpec]:
"""Embed encrypted credentials into each server spec.

Converts slug strings to full specs and adds credentials to all servers.
Each server receives only its own credentials, matched by connection name
via :func:`~dedalus_labs.lib.mcp.wire.slug_to_connection_name`.

Args:
servers: Serialized MCP servers (slug strings or spec dicts).
encrypted: EncryptedCredentials instance.

Returns:
List of MCPServerSpec dicts with credentials embedded.
List of MCPServerSpec dicts with per-server credentials embedded.

"""
creds_dict = encrypted.to_dict()
all_creds = encrypted.to_dict()
result: List[MCPServerSpec] = []

for server in servers:
if isinstance(server, str):
creds = _credentials_for_server(server, all_creds)
if server.startswith(("http://", "https://")):
result.append({"url": server, "name": server, "credentials": creds_dict})
result.append({"url": server, "name": server, "credentials": creds})
else:
result.append({"slug": server, "name": server, "credentials": creds_dict})
result.append({"slug": server, "name": server, "credentials": creds})
elif isinstance(server, dict):
# Existing spec -> add name (if missing) and credentials
name = server.get("name") or server.get("slug") or server.get("url") or ""
spec: MCPServerSpec = {**server, "name": name, "credentials": creds_dict}
result.append(spec)
creds = _credentials_for_server(name, all_creds)
result.append({**server, "name": name, "credentials": creds})

return result
16 changes: 16 additions & 0 deletions src/dedalus_labs/lib/mcp/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# Helpers
"build_connection_record",
"collect_unique_connections",
"slug_to_connection_name",
]


Expand Down Expand Up @@ -370,6 +371,21 @@ def collect_unique_connections(servers: Sequence[MCPServerProtocol]) -> List[Any
return unique


def slug_to_connection_name(slug: str) -> str:
"""Derive the canonical connection name from a server slug.

Slugs use ``org/server`` format; connection names use dashes.

Args:
slug: Server slug, URL, or name string.

Returns:
Connection name with slashes replaced by dashes.

"""
return slug.replace("/", "-")


# --- Credential Matching ---


Expand Down
122 changes: 122 additions & 0 deletions tests/test_mcp_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# ==============================================================================
# © 2025 Dedalus Labs, Inc. and affiliates
# Licensed under MIT
# github.com/dedalus-labs/dedalus-sdk-python/LICENSE
# ==============================================================================

"""Tests for MCP request credential embedding."""

from __future__ import annotations

from typing import TypedDict
from dataclasses import dataclass

import pytest

from dedalus_labs.lib.mcp import request as mcp_request
from dedalus_labs.lib.mcp.wire import slug_to_connection_name


@dataclass(frozen=True)
class _FakeConnection:
name: str


@dataclass(frozen=True)
class _FakeCredential:
connection: _FakeConnection
encrypted_blob: str

def values_for_encryption(self) -> dict[str, str]:
return {"blob": self.encrypted_blob}


class _RequestPayload(TypedDict):
mcp_servers: list[str]
credentials: list[_FakeCredential]


class _FakeHTTPClient:
pass


class _FakePublicKey:
pass


def _build_payload() -> _RequestPayload:
return {
"mcp_servers": [
"dedalus-labs/gmail-mcp",
"dedalus-labs/slack-mcp",
"dedalus-labs/calendar-mcp",
],
"credentials": [
_FakeCredential(_FakeConnection("dedalus-labs-gmail-mcp"), "enc-gmail"),
_FakeCredential(_FakeConnection("dedalus-labs-slack-mcp"), "enc-slack"),
],
}


def _expected_mcp_servers() -> list[dict[str, str | dict[str, str] | None]]:
return [
{
"slug": "dedalus-labs/gmail-mcp",
"name": "dedalus-labs/gmail-mcp",
"credentials": {"dedalus-labs-gmail-mcp": "enc-gmail"},
},
{
"slug": "dedalus-labs/slack-mcp",
"name": "dedalus-labs/slack-mcp",
"credentials": {"dedalus-labs-slack-mcp": "enc-slack"},
},
{
"slug": "dedalus-labs/calendar-mcp",
"name": "dedalus-labs/calendar-mcp",
"credentials": None,
},
]


def _fake_encrypt_credentials(_public_key: _FakePublicKey, values: dict[str, str]) -> str:
return values["blob"]


class TestSlugToConnectionName:
def test_slug_to_connection_name(self) -> None:
assert slug_to_connection_name("dedalus-labs/gmail-mcp") == "dedalus-labs-gmail-mcp"


class TestPrepareMCPRequest:
def test_sync_embeds_per_server_credentials(self, monkeypatch: pytest.MonkeyPatch) -> None:
def _fake_fetch_sync(_http: _FakeHTTPClient, _url: str) -> _FakePublicKey:
return _FakePublicKey()

monkeypatch.setattr(mcp_request, "fetch_encryption_key_sync", _fake_fetch_sync)
monkeypatch.setattr(mcp_request, "encrypt_credentials", _fake_encrypt_credentials)

result = mcp_request.prepare_mcp_request_sync(
_build_payload(),
"https://auth.example.com",
_FakeHTTPClient(),
)

assert "credentials" not in result
assert result["mcp_servers"] == _expected_mcp_servers()

@pytest.mark.asyncio
async def test_async_embeds_per_server_credentials(self, monkeypatch: pytest.MonkeyPatch) -> None:
async def _fake_fetch_async(_http: _FakeHTTPClient, _url: str) -> _FakePublicKey:
return _FakePublicKey()

monkeypatch.setattr(mcp_request, "fetch_encryption_key", _fake_fetch_async)
monkeypatch.setattr(mcp_request, "encrypt_credentials", _fake_encrypt_credentials)

result = await mcp_request.prepare_mcp_request(
_build_payload(),
"https://auth.example.com",
_FakeHTTPClient(),
)

assert "credentials" not in result
assert result["mcp_servers"] == _expected_mcp_servers()
Loading