Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 157 additions & 2 deletions agent-service/src/server.test.ts → agent-service/src/server.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
* under the License.
*/

import { beforeEach, describe, expect, test } from "bun:test";
import { buildApp, _resetAgentStoreForTests } from "./server";
import { beforeEach, describe, expect, spyOn, test } from "bun:test";
import { buildApp, start, _resetAgentStoreForTests, _getAgentForTests } from "./server";
import { WorkflowSystemMetadata } from "./agent/util/workflow-system-metadata";
import { env } from "./config/env";

const API = env.API_PREFIX;
Expand Down Expand Up @@ -249,3 +250,157 @@ describe(`PATCH ${API}/agents/:id/settings`, () => {
expect(reread.toolTimeoutSeconds).toBe(30);
});
});

describe("agent creation edge cases", () => {
test("rejects an empty modelType", async () => {
// The body schema accepts any string, so the handler's own guard runs.
const res = await postJson(`${API}/agents`, { modelType: "" }, { Authorization: `Bearer ${TOKEN}` });
expect(res.status).toBe(400);
expect((await readJson<{ error: string }>(res)).error).toContain("modelType");
});

test("applies initial settings supplied at creation time", async () => {
const res = await createAgent({ settings: { maxSteps: 9, toolTimeoutSeconds: 12 } });
expect(res.status).toBe(200);
const body = await readJson<{ settings: { maxSteps: number; toolTimeoutSeconds: number } }>(res);
expect(body.settings.maxSteps).toBe(9);
expect(body.settings.toolTimeoutSeconds).toBe(12);
});

test("creates the agent even when the workflow load fails (non-fatal)", async () => {
// retrieveWorkflow targets the (unavailable) dashboard service; the failure
// is caught and the agent is still created.
const res = await createAgent({ workflowId: 123 });
expect(res.status).toBe(200);
});

test("masks the delegate token in agent info", async () => {
const id = (await readJson<{ id: string }>(await createAgent())).id;
_getAgentForTests(id)!.setDelegateConfig({
userToken: "super-secret",
userInfo: { uid: 1, email: "tester@example.com" },
workflowId: 5,
workflowName: "My Flow",
computingUnitId: 2,
} as any);

const info = await readJson<{ delegate?: { userToken: string; workflowName: string } }>(
await getJson(`${API}/agents/${id}`)
);
expect(info.delegate?.userToken).toBe("***");
expect(info.delegate?.workflowName).toBe("My Flow");
});
});

describe("agent read routes", () => {
let id: string;
beforeEach(async () => {
id = (await readJson<{ id: string }>(await createAgent())).id;
});

test("GET /:id/react-steps returns steps and state", async () => {
const body = await readJson<{ steps: unknown[]; state: string }>(await getJson(`${API}/agents/${id}/react-steps`));
expect(Array.isArray(body.steps)).toBe(true);
expect(body.state).toBe("AVAILABLE");
});

test("GET /:id/system-info responds", async () => {
const res = await getJson(`${API}/agents/${id}/system-info`);
expect(res.status).toBe(200);
});

test("GET /:id/operator-types returns a list", async () => {
const res = await getJson(`${API}/agents/${id}/operator-types`);
expect(res.status).toBe(200);
expect(Array.isArray(await readJson(res))).toBe(true);
});

test("POST /:id/steps-by-operators returns steps", async () => {
const res = await postJson(`${API}/agents/${id}/steps-by-operators`, { operatorIds: [] });
expect(res.status).toBe(200);
expect(Array.isArray((await readJson<{ steps: unknown[] }>(res)).steps)).toBe(true);
});

test("GET /:id/operator-results maps the visible operator results", async () => {
const agent = _getAgentForTests(id)!;
(agent as any).getWorkflowResultState = () => ({
getAllVisible: () =>
new Map([
[
"op-1",
{
operatorInfo: {
state: "COMPLETED",
inputTuples: 1,
outputTuples: 2,
inputPortShapes: [],
result: [{ a: 1 }],
error: undefined,
warnings: [],
consoleLogs: [],
totalRowCount: 2,
resultStatistics: {},
},
},
],
]),
});

const body = await readJson<{ results: Record<string, { outputTuples: number; outputColumns: number }> }>(
await getJson(`${API}/agents/${id}/operator-results`)
);
expect(body.results["op-1"].outputTuples).toBe(2);
expect(body.results["op-1"].outputColumns).toBe(1);
});
});

describe("checkout route", () => {
test("broadcasts and survives a websocket whose send throws", async () => {
const id = (await readJson<{ id: string }>(await createAgent())).id;
const agent = _getAgentForTests(id)!;
(agent as any).checkout = () => true;
(agent as any).getAllSteps = () => [];
// A failing socket must be dropped inside broadcastToAgent, not crash the request.
agent.addWebsocket({
send: () => {
throw new Error("send failed");
},
} as any);

const res = await postJson(`${API}/agents/${id}/checkout`, { stepId: "step-1" });
expect(res.status).toBe(200);
expect((await readJson<{ headId: string }>(res)).headId).toBe("step-1");
});

test("returns 500 when the step cannot be found", async () => {
const id = (await readJson<{ id: string }>(await createAgent())).id;
(_getAgentForTests(id) as any).checkout = () => false;
const res = await postJson(`${API}/agents/${id}/checkout`, { stepId: "missing" });
expect(res.status).toBe(500);
});
});

describe("non-router routes", () => {
test("unknown routes fall through to the catch-all error handler", async () => {
const res = await getJson("/no-such-route");
expect(res.status).toBe(500);
});
});

describe("start()", () => {
test("boots a listening app and prints the startup banner", async () => {
const booted = await start();
expect(typeof booted.server?.port).toBe("number");
await booted.stop();
});

test("tolerates a metadata-initialization failure", async () => {
const spy = spyOn(WorkflowSystemMetadata, "initializeGlobal").mockImplementation(async () => {
throw new Error("metadata unavailable");
});
const booted = await start();
await booted.stop();
expect(spy).toHaveBeenCalled();
spy.mockRestore();
});
});
94 changes: 32 additions & 62 deletions agent-service/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ import type {
AgentSettingsApi,
ReActStep,
} from "./types/agent";
import { OperatorResultSerializationMode } from "./types/agent";
import { AgentState, OperatorResultSerializationMode } from "./types/agent";
import type { WsClientRequest, WsServerMessage, WsServerSnapshotMessage, OperatorResultSummaryWs } from "./types/ws";

const agentStore = new Map<string, TexeraAgent>();
let agentCounter = 0;
Expand Down Expand Up @@ -410,37 +411,6 @@ const agentsRouter = new Elysia({ prefix: "/agents" })
}
);

interface WsMessage {
type: "message" | "stop";
content?: string;
messageSource?: "chat" | "feedback";
}

interface OperatorResultSummaryWs {
state: string;
inputTuples: number;
outputTuples: number;
inputPortShapes?: { portIndex: number; rows: number; columns: number }[];
outputColumns?: number;
error?: string;
warnings?: string[];
consoleLogCount?: number;
totalRowCount?: number;
sampleRecords?: Record<string, any>[];
resultStatistics?: Record<string, string>;
}

interface WsOutgoingMessage {
type: "step" | "state" | "error" | "complete" | "init" | "headChange";
step?: ReActStep;
state?: string;
error?: string;
steps?: ReActStep[];
headId?: string;
operatorResults?: Record<string, OperatorResultSummaryWs>;
workflowContent?: any;
}

function getOperatorResultSummaries(agent: TexeraAgent): Record<string, OperatorResultSummaryWs> {
const resultState = agent.getWorkflowResultState();
const visible = resultState.getAllVisible();
Expand All @@ -464,7 +434,7 @@ function getOperatorResultSummaries(agent: TexeraAgent): Record<string, Operator
return results;
}

function broadcastToAgent(agentId: string, message: WsOutgoingMessage): void {
function broadcastToAgent(agentId: string, message: WsServerMessage): void {
const agent = agentStore.get(agentId);
if (!agent) return;

Expand Down Expand Up @@ -504,14 +474,13 @@ export function buildApp() {

agent.addWebsocket(ws);

const initMessage: WsOutgoingMessage = {
type: "init",
const snapshotMessage: WsServerSnapshotMessage = {
type: "snapshot",
state: agent.getState(),
steps: agent.getAllSteps(),
headId: agent.getHead(),
operatorResults: getOperatorResultSummaries(agent),
};
ws.send(JSON.stringify(initMessage));
ws.send(JSON.stringify(snapshotMessage));
},

async message(ws, messageData) {
Expand All @@ -523,21 +492,23 @@ export function buildApp() {
return;
}

let msg: WsMessage;
let msg: WsClientRequest;
try {
msg = typeof messageData === "string" ? JSON.parse(messageData) : (messageData as WsMessage);
msg = typeof messageData === "string" ? JSON.parse(messageData) : (messageData as WsClientRequest);
} catch {
ws.send(JSON.stringify({ type: "error", error: "Invalid message format" }));
return;
}

if (msg.type === "stop") {
agent.stop();
broadcastToAgent(agentId, { type: "state", state: "STOPPING" });
if (msg.type === "command") {
if (msg.commandType === "stop") {
agent.stop();
broadcastToAgent(agentId, { type: "status", state: AgentState.STOPPING });
}
return;
}

if (msg.type === "message") {
if (msg.type === "prompt") {
if (!msg.content || typeof msg.content !== "string") {
ws.send(JSON.stringify({ type: "error", error: "Message content is required" }));
return;
Expand All @@ -546,15 +517,10 @@ export function buildApp() {
wsLog.info({ agentId, preview: msg.content.substring(0, 50) }, "received message");

agent.setStepCallback((step: ReActStep) => {
const hasToolCalls = step.toolCalls && step.toolCalls.length > 0;
broadcastToAgent(agentId, {
type: "step",
step,
...(hasToolCalls ? { operatorResults: getOperatorResultSummaries(agent) } : {}),
});
broadcastToAgent(agentId, { type: "step", step });
});

broadcastToAgent(agentId, { type: "state", state: "GENERATING" });
broadcastToAgent(agentId, { type: "status", state: AgentState.GENERATING });

try {
const result = await agent.sendMessage(msg.content, msg.messageSource);
Expand All @@ -567,16 +533,16 @@ export function buildApp() {
broadcastToAgent(agentId, { type: "step", step: lastStep });
}

broadcastToAgent(agentId, {
type: "complete",
state: agent.getState(),
operatorResults: getOperatorResultSummaries(agent),
});

wsLog.info({ agentId, steps: result.messages.length }, "agent run complete");
} catch (error: any) {
agent.setStepCallback(null);
broadcastToAgent(agentId, { type: "error", error: error.message });
} finally {
// The run is over (success or failure) and TexeraAgent.sendMessage has
// reset the agent to its resting state (AVAILABLE) in its own finally.
// This status frame is the run-end signal (it also unsticks the client
// from GENERATING after errors).
broadcastToAgent(agentId, { type: "status", state: agent.getState() });
}
}
},
Expand Down Expand Up @@ -605,6 +571,12 @@ export function _resetAgentStoreForTests(): void {
agentCounter = 0;
}

// Look up an agent instance by id. Used by tests to stub agent behavior (e.g.
// `sendMessage`) when exercising the WebSocket handlers.
export function _getAgentForTests(agentId: string): TexeraAgent | undefined {
return agentStore.get(agentId);
}

function printStartupMessage(app: ReturnType<typeof buildApp>) {
const LINE = "=".repeat(60);
console.log(LINE);
Expand All @@ -630,9 +602,9 @@ function printStartupMessage(app: ReturnType<typeof buildApp>) {
for (const route of wsRoutes) {
console.log(` WS ${route.path}`);
}
console.log(" Send: { type: 'message', content: '...' }");
console.log(" Send: { type: 'stop' }");
console.log(" Recv: { type: 'step' | 'state' | 'complete' | 'error' | 'init', ... }");
console.log(" Send: { type: 'prompt', content: '...' }");
console.log(" Send: { type: 'command', commandType: 'stop' }");
console.log(" Recv: { type: 'snapshot' | 'step' | 'status' | 'error' | 'headChange', ... }");
}

console.log("");
Expand Down Expand Up @@ -665,6 +637,4 @@ export async function start() {

// Run the server only when this file is the entry point, not when it is
// imported by tests or other modules.
if (import.meta.main) {
start();
}
if (import.meta.main) start();
Loading
Loading