Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions strands-py-wasm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
# Imports as `wasmtime`. Pinned to pgrayy/wasm-deps git URL until upstream PRs land.
"pgrayy-wasmtime @ git+https://github.com/pgrayy/wasm-deps.git@4b5dc41512109ebafe4c4f1edd592c739872c640#subdirectory=wasmtime-py",
# Imports as `wasmtime`. Pinned to a pgrayy/wasm-deps release wheel until upstream PRs land.
"pgrayy-wasmtime @ https://github.com/pgrayy/wasm-deps/releases/download/wasmtime-v46.0.7/pgrayy_wasmtime-46.0.7-py3-none-macosx_11_0_arm64.whl",
Comment thread
pgrayy marked this conversation as resolved.
"botocore>=1.29.0,<2.0.0",
]

Expand Down
83 changes: 31 additions & 52 deletions strands-py-wasm/src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,40 +266,24 @@ def __init__(self, *, interrupt_id: str, response: Any) -> None:
super().__init__(interrupt_id=interrupt_id, response=payload)


class PydanticTool:
Comment thread
pgrayy marked this conversation as resolved.
"""Tool whose input schema is derived from a pydantic ``BaseModel``."""
class DecoratedTool:
"""A Python function exposed to the agent as a tool.

def __init__(
self,
*,
name: str,
description: str,
input_model: type,
func: Callable[..., Any],
) -> None:
if not hasattr(input_model, "model_json_schema") or not hasattr(input_model, "model_validate"):
raise TypeError(f"input_model must be a pydantic BaseModel subclass; got {input_model!r}")
self.name = name
self.description = description
self._input_model = input_model
self.input_schema = input_model.model_json_schema()
self.func = func
Build one with the ``@tool`` decorator, then pass it to
:class:`Agent` via ``tools=[...]``. The agent will call the function
when the model invokes the tool by name.

def to_spec(self) -> types.ToolSpec:
return types.ToolSpec(
name=self.name,
description=self.description,
input_schema=json.dumps(self.input_schema),
)
Callbacks must be synchronous. Async functions are not yet supported.
Comment thread
pgrayy marked this conversation as resolved.

def invoke(self, raw_input: str) -> list[Any]:
payload = json.loads(raw_input) if raw_input else {}
validated = self._input_model.model_validate(payload)
return _coerce_tool_result(self.func(validated))
Example::

@tool
def get_weather(city: str) -> str:
\"\"\"Return the current weather for a city.\"\"\"
return f"It is 72F and sunny in {city}."

class Tool:
"""Registered tool: spec plus Python callable."""
agent = Agent(model=BedrockModel(...), tools=[get_weather])
"""

def __init__(
self,
Expand All @@ -321,12 +305,13 @@ def to_spec(self) -> types.ToolSpec:
input_schema=json.dumps(self.input_schema),
)

def invoke(self, raw_input: str) -> list[Any]:
def invoke(self, raw_input: str) -> list[types.ToolResultContent]:
"""Run the tool with a JSON object of keyword arguments."""
kwargs = json.loads(raw_input) if raw_input else {}
return _coerce_tool_result(self.func(**kwargs))
return _normalize_tool_result(self.func(**kwargs))


def _coerce_tool_result(result: Any) -> list[Any]:
def _normalize_tool_result(result: Any) -> list[types.ToolResultContent]:
if isinstance(result, str):
return [types.ToolResultContent.Text(types.TextBlock(text=result))]
if isinstance(result, types.TextBlock):
Expand All @@ -348,9 +333,12 @@ def tool(
name: str | None = None,
description: str | None = None,
) -> Any:
"""Decorator that turns a Python function into a :class:`Tool`."""
"""Decorator that turns a Python function into a :class:`DecoratedTool`.

Only synchronous functions are supported at this time.
"""

def wrap(f: Callable[..., Any]) -> Tool:
def wrap(f: Callable[..., Any]) -> DecoratedTool:
hints = get_type_hints(f)
sig = inspect.signature(f)
properties: dict[str, Any] = {}
Expand All @@ -362,7 +350,7 @@ def wrap(f: Callable[..., Any]) -> Tool:
schema: dict[str, Any] = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return Tool(
return DecoratedTool(
name=name or f.__name__,
description=description or (f.__doc__ or "").strip() or f.__name__,
input_schema=schema,
Expand All @@ -372,19 +360,10 @@ def wrap(f: Callable[..., Any]) -> Tool:
return wrap(func) if func is not None else wrap


_ToolInput = Tool | PydanticTool | Callable[..., Any]
# String shorthand picks a tool by name; otherwise pass a tagged ToolChoice arm.
_ToolChoiceInput = str | types.ToolChoice | None


def _coerce_tool(item: _ToolInput) -> Tool | PydanticTool:
if isinstance(item, (Tool, PydanticTool)):
return item
if callable(item):
return tool(item)
raise TypeError(f"unsupported tool: {type(item).__name__}")


class Agent:
"""Strands agent. Construct once; call :meth:`invoke` or :meth:`stream_async`."""

Expand All @@ -394,7 +373,7 @@ def __init__(
model: types.ModelInput | None = None,
messages: list[types.Message] | None = None,
system_prompt: PromptInput | None = None,
tools: list[_ToolInput] | None = None,
tools: list[DecoratedTool] | None = None,
agent_tools: list[types.AgentAsToolConfig] | None = None,
vended_tools: list[types.VendedToolInput] | None = None,
vended_plugins: list[types.VendedPluginInput] | None = None,
Expand All @@ -413,7 +392,7 @@ def __init__(
app_state: dict[str, Any] | None = None,
model_state: dict[str, Any] | None = None,
) -> None:
self._tools: list[Tool | PydanticTool] = [_coerce_tool(t) for t in (tools or [])]
self._tools: list[DecoratedTool] = list(tools or [])
identity = None
if name is not None or id is not None or description is not None:
identity = types.AgentIdentity(name=name, id=id, description=description)
Expand Down Expand Up @@ -467,7 +446,7 @@ async def _ensure_runtime_async(self) -> Any:
await rt.async_init()
return rt

def _lookup_tool(self, name: str) -> Tool | PydanticTool:
def _lookup_tool(self, name: str) -> DecoratedTool:
for t in self._tools:
if getattr(t, "name", None) == name:
return t
Expand All @@ -476,11 +455,11 @@ def _lookup_tool(self, name: str) -> Tool | PydanticTool:
def _build_invoke_args(
self,
prompt: PromptInput,
tools: list[_ToolInput] | None,
tools: list[DecoratedTool] | None,
tool_choice: _ToolChoiceInput,
structured_output_schema: str | None,
) -> types.InvokeArgs:
extra_tools = [_coerce_tool(t).to_spec() for t in (tools or [])] or None
extra_tools = [t.to_spec() for t in (tools or [])] or None
return types.InvokeArgs(
input=_marshalling.coerce_prompt(prompt),
tools=extra_tools,
Expand All @@ -492,7 +471,7 @@ async def stream_async(
self,
prompt: PromptInput,
*,
tools: list[_ToolInput] | None = None,
tools: list[DecoratedTool] | None = None,
tool_choice: _ToolChoiceInput = None,
structured_output_schema: str | None = None,
) -> AsyncIterator[types.StreamEvent]:
Expand All @@ -507,7 +486,7 @@ async def invoke_async(
self,
prompt: PromptInput,
*,
tools: list[_ToolInput] | None = None,
tools: list[DecoratedTool] | None = None,
tool_choice: _ToolChoiceInput = None,
structured_output_schema: str | None = None,
) -> AgentResult:
Expand All @@ -526,7 +505,7 @@ def invoke(
self,
prompt: PromptInput,
*,
tools: list[_ToolInput] | None = None,
tools: list[DecoratedTool] | None = None,
tool_choice: _ToolChoiceInput = None,
structured_output_schema: str | None = None,
) -> AgentResult:
Expand Down
29 changes: 29 additions & 0 deletions strands-py-wasm/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from strands import Agent, BedrockModel, tool


@pytest.fixture
def model():
return BedrockModel("us.anthropic.claude-haiku-4-5-20251001-v1:0")


@pytest.fixture
def weather_tool():
@tool
def get_weather(city: str) -> str:
"""Return the current weather for a city."""
return f"It is 72F and sunny in {city}."

return get_weather


@pytest.fixture
def agent(model, weather_tool):
return Agent(model=model, tools=[weather_tool])


@pytest.mark.asyncio
async def test_decorated_tool_invocation(agent):
result = await agent.invoke_async("What is the weather in Seattle?")
assert "72" in str(result)
88 changes: 69 additions & 19 deletions strands-wasm/entry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
/// <reference path="./generated/interfaces/strands-agent-tool-provider.d.ts" />

import type { AgentConfig, InvokeArgs, RespondArgs, AgentError } from 'strands:agent/api@0.1.0'
import type { Message as WitMessage, PromptInput } from 'strands:agent/messages@0.1.0'
import type {
ContentBlock as WitContentBlock,
Message as WitMessage,
PromptInput,
ToolResultBlock as WitToolResultBlock,
ToolResultContent as WitToolResultContent,
} from 'strands:agent/messages@0.1.0'
import type {
StreamEvent as WitStreamEvent,
StopEvent as WitStopEvent,
Expand Down Expand Up @@ -52,6 +58,8 @@ import type {
ToolChoice,
ModelStreamEvent,
ContentBlock,
ToolResultBlock,
ToolResultContent,
SaveLatestStrategy,
JSONValue,
} from '@strands-agents/sdk'
Expand Down Expand Up @@ -129,29 +137,62 @@ function mapMessage(message: Message): WitMessage {
} as WitMessage
}

/** Serialize a TS SDK ContentBlock to the WIT tagged-variant shape. */
function mapContentBlock(block: ContentBlock): import('strands:agent/messages@0.1.0').ContentBlock {
type WitBlock = import('strands:agent/messages@0.1.0').ContentBlock
// block.type is the SDK class discriminator; toJSON drops class identity but keeps fields.
/**
* Serialize a TS SDK ContentBlock to the WIT `content-block` tagged variant.
*
* Most SDK blocks `toJSON()` to `{<discriminator>: <inner-data>}` (e.g.
* `{toolUse: {...}}`). The matching WIT record is the inner shape only;
* the discriminator already lives in `tag`. We strip that outer wrapper
* here so the WIT marshaler sees the right fields. `text` and `json` are
* the exception: their toJSON is already the inner shape.
*/
function mapContentBlock(block: ContentBlock): WitContentBlock {
const payload = JSON.parse(JSON.stringify(block))
switch (block.type) {
case 'textBlock': return { tag: 'text', val: payload } as WitBlock
case 'toolUseBlock': return { tag: 'tool-use', val: payload } as WitBlock
case 'toolResultBlock': return { tag: 'tool-result', val: payload } as WitBlock
case 'reasoningBlock': return { tag: 'reasoning', val: payload } as WitBlock
case 'cachePointBlock': return { tag: 'cache-point', val: payload } as WitBlock
case 'imageBlock': return { tag: 'image', val: payload } as WitBlock
case 'videoBlock': return { tag: 'video', val: payload } as WitBlock
case 'documentBlock': return { tag: 'document', val: payload } as WitBlock
case 'citationsBlock': return { tag: 'citations', val: payload } as WitBlock
case 'guardContentBlock': return { tag: 'guard-content', val: payload } as WitBlock
case 'textBlock': return { tag: 'text', val: payload } as WitContentBlock
case 'toolUseBlock':
return { tag: 'tool-use', val: { ...payload.toolUse, input: JSON.stringify(payload.toolUse.input) } } as WitContentBlock
case 'toolResultBlock':
return { tag: 'tool-result', val: mapToolResultBlock(block) } as WitContentBlock
case 'reasoningBlock': return { tag: 'reasoning', val: payload.reasoning } as WitContentBlock
case 'cachePointBlock': return { tag: 'cache-point', val: payload.cachePoint } as WitContentBlock
case 'imageBlock': return { tag: 'image', val: payload.image } as WitContentBlock
case 'videoBlock': return { tag: 'video', val: payload.video } as WitContentBlock
case 'documentBlock': return { tag: 'document', val: payload.document } as WitContentBlock
case 'citationsBlock': return { tag: 'citations', val: payload.citations } as WitContentBlock
case 'guardContentBlock': return { tag: 'guard-content', val: payload.guardContent } as WitContentBlock
default: {
block satisfies never
throw new Error(`unknown content block: ${(block as { type: string }).type}`)
}
}
}

/** Serialize a TS SDK `ToolResultBlock` to the WIT `tool-result-block` record. */
function mapToolResultBlock(block: ToolResultBlock): WitToolResultBlock {
return {
toolUseId: block.toolUseId,
status: block.status,
content: block.content.map(mapToolResultContent),
}
}

/** Serialize a TS SDK `ToolResultContent` to the WIT `tool-result-content` tagged variant. */
function mapToolResultContent(block: ToolResultContent): WitToolResultContent {
const payload = JSON.parse(JSON.stringify(block))
switch (block.type) {
case 'textBlock': return { tag: 'text', val: payload } as WitToolResultContent
case 'jsonBlock': return { tag: 'json', val: payload } as WitToolResultContent
case 'imageBlock': return { tag: 'image', val: payload.image } as WitToolResultContent
case 'videoBlock': return { tag: 'video', val: payload.video } as WitToolResultContent
case 'documentBlock': return { tag: 'document', val: payload.document } as WitToolResultContent
default: {
block satisfies never
throw new Error(`unsupported tool-result-content type: ${(block as { type: string }).type}`)
}
}
}

//
// --- stream event mapping ------------------------------------------------
//
Expand Down Expand Up @@ -211,7 +252,7 @@ function mapEvent(event: AgentStreamEvent): WitStreamEvent | null {
toolUseId: event.toolUse.toolUseId,
input: JSON.stringify(event.toolUse.input ?? {}),
},
toolResult: mapContentBlock(event.result) as unknown as import('strands:agent/messages@0.1.0').ToolResultBlock,
toolResult: mapToolResultBlock(event.result),
error: event.error ? { tag: 'execution-failed', val: event.error.message } : undefined,
},
}
Expand All @@ -225,7 +266,7 @@ function mapEvent(event: AgentStreamEvent): WitStreamEvent | null {
case 'toolResultEvent':
return {
tag: 'tool-result-hook',
val: { toolResult: mapContentBlock(event.result) as unknown as import('strands:agent/messages@0.1.0').ToolResultBlock },
val: { toolResult: mapToolResultBlock(event.result) },
}
case 'toolStreamUpdateEvent':
return { tag: 'tool-update', val: { data: JSON.stringify(event.event.data ?? null) } }
Expand Down Expand Up @@ -365,8 +406,17 @@ function createTools(specs: ToolSpec[] | undefined): FunctionTool[] | undefined
case 'data':
// Streaming tool progress is not surfaced to the SDK caller today.
continue
case 'complete':
return value.val as unknown as JSONValue
case 'complete': {
// The host pushes WIT `tool-result-content` variant arms; the
// TS FunctionTool expects the single-key data shape that
// `toolResultContentFromData` accepts. text/json arms already
// carry that shape inline; other arms need an explicit wrap.
const content = (value.val as Array<{ tag: string; val: unknown }>).map((c) => {
if (c.tag === 'text' || c.tag === 'json') return c.val
return { [c.tag]: c.val }
})
return content as unknown as JSONValue
}
case 'error':
throw new Error(`tool ${spec.name} failed: ${value.val.tag}`)
}
Expand Down
Loading