diff --git a/.changeset/cruel-corners-feel.md b/.changeset/cruel-corners-feel.md new file mode 100644 index 000000000..8bf9ab94f --- /dev/null +++ b/.changeset/cruel-corners-feel.md @@ -0,0 +1,8 @@ +--- +"@workflow/web-shared": patch +"@workflow/swc-plugin": patch +"@workflow/world": patch +"@workflow/core": patch +--- + +Add support for closure scope vars in step functions diff --git a/packages/core/e2e/e2e.test.ts b/packages/core/e2e/e2e.test.ts index 27cfaeba3..72991d163 100644 --- a/packages/core/e2e/e2e.test.ts +++ b/packages/core/e2e/e2e.test.ts @@ -737,4 +737,18 @@ describe('e2e', () => { expect(stepCompletedEvents).toHaveLength(1); } ); + + test( + 'closureVariableWorkflow - nested step functions with closure variables', + { timeout: 60_000 }, + async () => { + // This workflow uses a nested step function that references closure variables + // from the parent workflow scope (multiplier, prefix, baseValue) + const run = await triggerWorkflow('closureVariableWorkflow', [7]); + const returnValue = await getWorkflowReturnValue(run.runId); + + // Expected: baseValue (7) * multiplier (3) = 21, prefixed with "Result: " + expect(returnValue).toBe('Result: 21'); + } + ); }); diff --git a/packages/core/src/global.ts b/packages/core/src/global.ts index 1685af0cb..3d4d0eda5 100644 --- a/packages/core/src/global.ts +++ b/packages/core/src/global.ts @@ -5,6 +5,7 @@ export interface StepInvocationQueueItem { correlationId: string; stepName: string; args: Serializable[]; + closureVars?: Record; } export interface HookInvocationQueueItem { diff --git a/packages/core/src/private.ts b/packages/core/src/private.ts index 710ca463c..2d284e467 100644 --- a/packages/core/src/private.ts +++ b/packages/core/src/private.ts @@ -29,6 +29,12 @@ export function getStepFunction(stepId: string): StepFunction | undefined { return registeredSteps.get(stepId); } +/** + * Get closure variables for the current step function + * @internal + */ +export { __private_getClosureVars } from './step/get-closure-vars.js'; + export interface WorkflowOrchestratorContext { globalThis: typeof globalThis; eventsConsumer: EventsConsumer; diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts index edd112866..edf5e8a8e 100644 --- a/packages/core/src/runtime.ts +++ b/packages/core/src/runtime.ts @@ -389,8 +389,11 @@ export function workflowEntrypoint(workflowCode: string) { if (queueItem.type === 'step') { // Handle step operations const ops: Promise[] = []; - const dehydratedArgs = dehydrateStepArguments( - queueItem.args, + const dehydratedInput = dehydrateStepArguments( + { + args: queueItem.args, + closureVars: queueItem.closureVars, + }, err.globalThis ); @@ -398,7 +401,7 @@ export function workflowEntrypoint(workflowCode: string) { const step = await world.steps.create(runId, { stepId: queueItem.correlationId, stepName: queueItem.stepName, - input: dehydratedArgs as Serializable[], + input: dehydratedInput as Serializable, }); waitUntil( @@ -678,9 +681,15 @@ export const stepEntrypoint = `Step "${stepId}" has no "startedAt" timestamp` ); } - // Hydrate the step input arguments + // Hydrate the step input arguments and closure variables const ops: Promise[] = []; - const args = hydrateStepArguments(step.input, ops, workflowRunId); + const hydratedInput = hydrateStepArguments( + step.input, + ops, + workflowRunId + ); + + const args = hydratedInput.args; span?.setAttributes({ ...Attribute.StepArgumentsCount(args.length), @@ -703,8 +712,9 @@ export const stepEntrypoint = : `http://localhost:${port ?? 3000}`, }, ops, + closureVars: hydratedInput.closureVars, }, - () => stepFn(...args) + () => stepFn.apply(null, args) ); // NOTE: None of the code from this point is guaranteed to run diff --git a/packages/core/src/step.test.ts b/packages/core/src/step.test.ts index 6bf2957c8..8d2bdd6d3 100644 --- a/packages/core/src/step.test.ts +++ b/packages/core/src/step.test.ts @@ -199,4 +199,75 @@ describe('createUseStep', () => { await myStepFunction(); expect(ctx.onWorkflowError).not.toHaveBeenCalled(); }); + + it('should capture closure variables when provided', async () => { + const ctx = setupWorkflowContext([ + { + eventId: 'evnt_0', + runId: 'wrun_123', + eventType: 'step_completed', + correlationId: 'step_01K11TFZ62YS0YYFDQ3E8B9YCV', + eventData: { + result: ['Result: 42'], + }, + createdAt: new Date(), + }, + ]); + + const useStep = createUseStep(ctx); + const count = 42; + const prefix = 'Result: '; + + // Create step with closure variables function + const calculate = useStep('calculate', () => ({ count, prefix })); + + // Call the step + const result = await calculate(); + + // Verify result + expect(result).toBe('Result: 42'); + + // Verify closure variables were added to invocation queue + expect(ctx.invocationsQueue).toHaveLength(1); + expect(ctx.invocationsQueue[0]).toMatchObject({ + type: 'step', + stepName: 'calculate', + args: [], + closureVars: { count: 42, prefix: 'Result: ' }, + }); + }); + + it('should handle empty closure variables', async () => { + const ctx = setupWorkflowContext([ + { + eventId: 'evnt_0', + runId: 'wrun_123', + eventType: 'step_completed', + correlationId: 'step_01K11TFZ62YS0YYFDQ3E8B9YCV', + eventData: { + result: [5], + }, + createdAt: new Date(), + }, + ]); + + const useStep = createUseStep(ctx); + + // Create step without closure variables + const add = useStep('add'); + + // Call the step + const result = await add(2, 3); + + // Verify result + expect(result).toBe(5); + + // Verify empty closure variables were added to invocation queue + expect(ctx.invocationsQueue).toHaveLength(1); + expect(ctx.invocationsQueue[0]).toMatchObject({ + type: 'step', + stepName: 'add', + args: [2, 3], + }); + }); }); diff --git a/packages/core/src/step.ts b/packages/core/src/step.ts index 875308010..bf0fc1cc0 100644 --- a/packages/core/src/step.ts +++ b/packages/core/src/step.ts @@ -1,7 +1,7 @@ import { FatalError, WorkflowRuntimeError } from '@workflow/errors'; import { withResolvers } from '@workflow/utils'; import { EventConsumerResult } from './events-consumer.js'; -import { WorkflowSuspension } from './global.js'; +import { type StepInvocationQueueItem, WorkflowSuspension } from './global.js'; import { stepLogger } from './logger.js'; import type { WorkflowOrchestratorContext } from './private.js'; import type { Serializable } from './schemas.js'; @@ -9,18 +9,28 @@ import { hydrateStepReturnValue } from './serialization.js'; export function createUseStep(ctx: WorkflowOrchestratorContext) { return function useStep( - stepName: string + stepName: string, + closureVarsFn?: () => Record ) { const stepFunction = (...args: Args): Promise => { const { promise, resolve, reject } = withResolvers(); const correlationId = `step_${ctx.generateUlid()}`; - ctx.invocationsQueue.push({ + + const queueItem: StepInvocationQueueItem = { type: 'step', correlationId, stepName, args, - }); + }; + + // Invoke the closure variables function to get the closure scope + const closureVars = closureVarsFn?.(); + if (closureVars) { + queueItem.closureVars = closureVars; + } + + ctx.invocationsQueue.push(queueItem); // Track whether we've already seen a "step_started" event for this step. // This is important because after a retryable failure, the step moves back to diff --git a/packages/core/src/step/context-storage.ts b/packages/core/src/step/context-storage.ts index 4749f2ff1..e9c1abbf3 100644 --- a/packages/core/src/step/context-storage.ts +++ b/packages/core/src/step/context-storage.ts @@ -6,4 +6,5 @@ export const contextStorage = /* @__PURE__ */ new AsyncLocalStorage<{ stepMetadata: StepMetadata; workflowMetadata: WorkflowMetadata; ops: Promise[]; + closureVars?: Record; }>(); diff --git a/packages/core/src/step/get-closure-vars.ts b/packages/core/src/step/get-closure-vars.ts new file mode 100644 index 000000000..a5d1d5193 --- /dev/null +++ b/packages/core/src/step/get-closure-vars.ts @@ -0,0 +1,18 @@ +import { contextStorage } from './context-storage.js'; + +/** + * Returns the closure variables for the current step function. + * This is an internal function used by the SWC transform to access + * variables from the parent workflow scope. + * + * @internal + */ +export function __private_getClosureVars(): Record { + const ctx = contextStorage.getStore(); + if (!ctx) { + throw new Error( + 'Closure variables can only be accessed inside a step function' + ); + } + return ctx.closureVars || {}; +} diff --git a/packages/core/src/workflow.test.ts b/packages/core/src/workflow.test.ts index 93d868ede..e587ba3f9 100644 --- a/packages/core/src/workflow.test.ts +++ b/packages/core/src/workflow.test.ts @@ -2304,4 +2304,96 @@ describe('runWorkflow', () => { ); }); }); + + describe('closure variables', () => { + it('should serialize and deserialize closure variables for nested step functions', async () => { + let error: Error | undefined; + try { + const ops: Promise[] = []; + const workflowRun: WorkflowRun = { + runId: 'test-run-123', + workflowName: 'workflow', + status: 'running', + input: dehydrateWorkflowArguments([], ops), + createdAt: new Date('2024-01-01T00:00:00.000Z'), + updatedAt: new Date('2024-01-01T00:00:00.000Z'), + startedAt: new Date('2024-01-01T00:00:00.000Z'), + deploymentId: 'test-deployment', + }; + + const events: Event[] = []; + + await runWorkflow( + `const useStep = globalThis[Symbol.for("WORKFLOW_USE_STEP")]; + async function workflow() { + const multiplier = 3; + const prefix = 'Result: '; + const calculate = useStep('step//input.js//_anonymousStep0', () => ({ multiplier, prefix })); + const result = await calculate(7); + return result; + }${getWorkflowTransformCode('workflow')}`, + workflowRun, + events + ); + } catch (err) { + error = err as Error; + } + + // Should suspend to create the step + assert(error); + expect(error.name).toEqual('WorkflowSuspension'); + expect((error as WorkflowSuspension).steps).toHaveLength(1); + + const step = (error as WorkflowSuspension).steps[0]; + expect(step).toMatchObject({ + type: 'step', + stepName: 'step//input.js//_anonymousStep0', + args: [7], + closureVars: { multiplier: 3, prefix: 'Result: ' }, + }); + }); + + it('should handle step functions without closure variables', async () => { + let error: Error | undefined; + try { + const ops: Promise[] = []; + const workflowRun: WorkflowRun = { + runId: 'test-run-123', + workflowName: 'workflow', + status: 'running', + input: dehydrateWorkflowArguments([], ops), + createdAt: new Date('2024-01-01T00:00:00.000Z'), + updatedAt: new Date('2024-01-01T00:00:00.000Z'), + startedAt: new Date('2024-01-01T00:00:00.000Z'), + deploymentId: 'test-deployment', + }; + + const events: Event[] = []; + + await runWorkflow( + `const add = globalThis[Symbol.for("WORKFLOW_USE_STEP")]("add"); + async function workflow() { + const result = await add(5, 10); + return result; + }${getWorkflowTransformCode('workflow')}`, + workflowRun, + events + ); + } catch (err) { + error = err as Error; + } + + // Should suspend to create the step + assert(error); + expect(error.name).toEqual('WorkflowSuspension'); + expect((error as WorkflowSuspension).steps).toHaveLength(1); + + const step = (error as WorkflowSuspension).steps[0]; + expect(step).toMatchObject({ + type: 'step', + stepName: 'add', + args: [5, 10], + }); + }); + }); }); diff --git a/packages/swc-plugin-workflow/transform/src/lib.rs b/packages/swc-plugin-workflow/transform/src/lib.rs index ac490f47a..5f7ac2a0d 100644 --- a/packages/swc-plugin-workflow/transform/src/lib.rs +++ b/packages/swc-plugin-workflow/transform/src/lib.rs @@ -1,7 +1,7 @@ mod naming; use serde::Deserialize; -use std::collections::{HashSet, HashMap}; +use std::collections::{HashMap, HashSet}; use swc_core::{ common::{DUMMY_SP, SyntaxContext, errors::HANDLER}, ecma::{ @@ -188,8 +188,8 @@ pub struct StepTransform { // (parent_var_name, prop_name, arrow_expr, span) object_property_step_functions: Vec<(String, String, ArrowExpr, swc_core::common::Span)>, // Track nested step functions inside workflow functions for hoisting in step mode - // (fn_name, fn_expr, span) - nested_step_functions: Vec<(String, FnExpr, swc_core::common::Span)>, + // (fn_name, fn_expr, span, closure_vars, was_arrow) + nested_step_functions: Vec<(String, FnExpr, swc_core::common::Span, Vec, bool)>, // Counter for anonymous function names #[allow(dead_code)] anonymous_fn_counter: usize, @@ -199,6 +199,8 @@ pub struct StepTransform { // Current context: variable name being processed when visiting object properties #[allow(dead_code)] current_var_context: Option, + // Track module-level imports to exclude from closure variables + module_imports: HashSet, } // Structure to track variable names and their access patterns @@ -262,6 +264,419 @@ impl TryFrom<&Expr> for Name { } } +// Visitor to collect closure variables from a nested step function +struct ClosureVariableCollector { + closure_vars: HashSet, + local_vars: HashSet, + params: HashSet, +} + +impl ClosureVariableCollector { + fn new() -> Self { + Self { + closure_vars: HashSet::new(), + local_vars: HashSet::new(), + params: HashSet::new(), + } + } + + fn collect_from_function(function: &Function, module_imports: &HashSet) -> Vec { + let mut collector = Self::new(); + + // Add module-level imports to local_vars so they're not considered closure vars + collector.local_vars.extend(module_imports.iter().cloned()); + + // Collect parameters + for param in &function.params { + collector.collect_param_names(¶m.pat); + } + + // Visit function body to collect references and declarations + if let Some(body) = &function.body { + collector.collect_from_block_stmt(body); + } + + // Return closure vars sorted for deterministic output + let mut vars: Vec = collector.closure_vars.into_iter().collect(); + vars.sort(); + vars + } + + fn collect_from_arrow_expr(arrow: &ArrowExpr, module_imports: &HashSet) -> Vec { + let mut collector = Self::new(); + + // Add module-level imports to local_vars so they're not considered closure vars + collector.local_vars.extend(module_imports.iter().cloned()); + + // Collect parameters + for param in &arrow.params { + collector.collect_param_names(param); + } + + // Visit arrow body + match &*arrow.body { + BlockStmtOrExpr::BlockStmt(block) => { + collector.collect_from_block_stmt(block); + } + BlockStmtOrExpr::Expr(expr) => { + collector.collect_from_expr(expr); + } + } + + // Return closure vars sorted for deterministic output + let mut vars: Vec = collector.closure_vars.into_iter().collect(); + vars.sort(); + vars + } + + fn collect_param_names(&mut self, pat: &Pat) { + match pat { + Pat::Ident(ident) => { + self.params.insert(ident.id.sym.to_string()); + } + Pat::Array(array) => { + for elem in array.elems.iter().flatten() { + self.collect_param_names(elem); + } + } + Pat::Object(obj) => { + for prop in &obj.props { + match prop { + ObjectPatProp::KeyValue(kv) => { + self.collect_param_names(&kv.value); + } + ObjectPatProp::Assign(assign) => { + self.params.insert(assign.key.id.sym.to_string()); + } + ObjectPatProp::Rest(rest) => { + self.collect_param_names(&rest.arg); + } + } + } + } + Pat::Rest(rest) => { + self.collect_param_names(&rest.arg); + } + Pat::Assign(assign) => { + self.collect_param_names(&assign.left); + } + _ => {} + } + } + + fn collect_from_block_stmt(&mut self, block: &BlockStmt) { + for stmt in &block.stmts { + self.collect_from_stmt(stmt); + } + } + + fn collect_from_stmt(&mut self, stmt: &Stmt) { + match stmt { + Stmt::Decl(decl) => { + match decl { + Decl::Var(var_decl) => { + for declarator in &var_decl.decls { + // Collect the declared variable names + self.collect_declared_names(&declarator.name); + // Then collect references in the initializer + if let Some(init) = &declarator.init { + self.collect_from_expr(init); + } + } + } + Decl::Fn(fn_decl) => { + self.local_vars.insert(fn_decl.ident.sym.to_string()); + // Don't visit nested function bodies for closure detection + } + _ => {} + } + } + Stmt::Expr(expr_stmt) => { + self.collect_from_expr(&expr_stmt.expr); + } + Stmt::If(if_stmt) => { + self.collect_from_expr(&if_stmt.test); + self.collect_from_stmt(&if_stmt.cons); + if let Some(alt) = &if_stmt.alt { + self.collect_from_stmt(alt); + } + } + Stmt::Return(ret_stmt) => { + if let Some(arg) = &ret_stmt.arg { + self.collect_from_expr(arg); + } + } + Stmt::Block(block) => { + self.collect_from_block_stmt(block); + } + Stmt::For(for_stmt) => { + if let Some(init) = &for_stmt.init { + match init { + VarDeclOrExpr::VarDecl(var_decl) => { + for declarator in &var_decl.decls { + self.collect_declared_names(&declarator.name); + if let Some(init) = &declarator.init { + self.collect_from_expr(init); + } + } + } + VarDeclOrExpr::Expr(expr) => { + self.collect_from_expr(expr); + } + } + } + if let Some(test) = &for_stmt.test { + self.collect_from_expr(test); + } + if let Some(update) = &for_stmt.update { + self.collect_from_expr(update); + } + self.collect_from_stmt(&for_stmt.body); + } + Stmt::While(while_stmt) => { + self.collect_from_expr(&while_stmt.test); + self.collect_from_stmt(&while_stmt.body); + } + _ => {} + } + } + + fn collect_declared_names(&mut self, pat: &Pat) { + match pat { + Pat::Ident(ident) => { + self.local_vars.insert(ident.id.sym.to_string()); + } + Pat::Array(array) => { + for elem in array.elems.iter().flatten() { + self.collect_declared_names(elem); + } + } + Pat::Object(obj) => { + for prop in &obj.props { + match prop { + ObjectPatProp::KeyValue(kv) => { + self.collect_declared_names(&kv.value); + } + ObjectPatProp::Assign(assign) => { + self.local_vars.insert(assign.key.id.sym.to_string()); + } + ObjectPatProp::Rest(rest) => { + self.collect_declared_names(&rest.arg); + } + } + } + } + Pat::Rest(rest) => { + self.collect_declared_names(&rest.arg); + } + Pat::Assign(assign) => { + self.collect_declared_names(&assign.left); + } + _ => {} + } + } + + fn collect_from_expr(&mut self, expr: &Expr) { + match expr { + Expr::Ident(ident) => { + let name = ident.sym.to_string(); + // Only add as closure var if it's not a parameter or local var + if !self.params.contains(&name) && !self.local_vars.contains(&name) { + // Filter out known globals + if !is_global_identifier(&name) { + self.closure_vars.insert(name); + } + } + } + Expr::Call(call) => { + if let Callee::Expr(callee) = &call.callee { + self.collect_from_expr(callee); + } + for arg in &call.args { + self.collect_from_expr(&arg.expr); + } + } + Expr::Member(member) => { + self.collect_from_expr(&member.obj); + } + Expr::Bin(bin) => { + self.collect_from_expr(&bin.left); + self.collect_from_expr(&bin.right); + } + Expr::Unary(unary) => { + self.collect_from_expr(&unary.arg); + } + Expr::Cond(cond) => { + self.collect_from_expr(&cond.test); + self.collect_from_expr(&cond.cons); + self.collect_from_expr(&cond.alt); + } + Expr::Array(array) => { + for elem in array.elems.iter().flatten() { + self.collect_from_expr(&elem.expr); + } + } + Expr::Object(obj) => { + for prop in &obj.props { + match prop { + PropOrSpread::Prop(prop) => { + match &**prop { + Prop::KeyValue(kv) => { + self.collect_from_expr(&kv.value); + } + Prop::Method(_method) => { + // Don't visit nested method bodies + } + _ => {} + } + } + PropOrSpread::Spread(spread) => { + self.collect_from_expr(&spread.expr); + } + } + } + } + Expr::Paren(paren) => { + self.collect_from_expr(&paren.expr); + } + Expr::Tpl(tpl) => { + for expr in &tpl.exprs { + self.collect_from_expr(expr); + } + } + Expr::TaggedTpl(tagged) => { + self.collect_from_expr(&tagged.tag); + for expr in &tagged.tpl.exprs { + self.collect_from_expr(expr); + } + } + Expr::Arrow(_arrow) => { + // Don't visit nested arrow function bodies for closure detection + } + Expr::Fn(_) => { + // Don't visit nested function bodies for closure detection + } + Expr::Assign(assign) => { + self.collect_from_expr(&assign.right); + // Also check the left side for references (e.g., obj.prop = value) + match &assign.left { + AssignTarget::Simple(simple) => { + match simple { + SimpleAssignTarget::Ident(ident) => { + // This is an assignment to a variable, check if it's a closure var + self.collect_from_ident_binding(&ident.id); + } + SimpleAssignTarget::Member(member) => { + self.collect_from_expr(&member.obj); + } + _ => {} + } + } + _ => {} + } + } + Expr::Update(update) => { + self.collect_from_expr(&update.arg); + } + Expr::Await(await_expr) => { + self.collect_from_expr(&await_expr.arg); + } + _ => {} + } + } + + fn collect_from_ident_binding(&mut self, ident: &Ident) { + let name = ident.sym.to_string(); + if !self.params.contains(&name) && !self.local_vars.contains(&name) { + if !is_global_identifier(&name) { + self.closure_vars.insert(name); + } + } + } +} + +fn is_global_identifier(name: &str) -> bool { + matches!( + name, + "console" + | "process" + | "global" + | "globalThis" + | "window" + | "document" + | "Array" + | "Object" + | "String" + | "Number" + | "Boolean" + | "Date" + | "Math" + | "JSON" + | "Promise" + | "Symbol" + | "Error" + | "TypeError" + | "ReferenceError" + | "SyntaxError" + | "RegExp" + | "Map" + | "Set" + | "WeakMap" + | "WeakSet" + | "parseInt" + | "parseFloat" + | "isNaN" + | "isFinite" + | "encodeURI" + | "decodeURI" + | "encodeURIComponent" + | "decodeURIComponent" + | "undefined" + | "null" + | "true" + | "false" + | "NaN" + | "Infinity" + | "setTimeout" + | "setInterval" + | "clearTimeout" + | "clearInterval" + | "fetch" + | "Response" + | "Request" + | "Headers" + | "URL" + | "URLSearchParams" + | "TextEncoder" + | "TextDecoder" + | "Buffer" + | "Uint8Array" + | "Int8Array" + | "Uint16Array" + | "Int16Array" + | "Uint32Array" + | "Int32Array" + | "Float32Array" + | "Float64Array" + | "BigInt" + | "BigInt64Array" + | "BigUint64Array" + | "DataView" + | "ArrayBuffer" + | "SharedArrayBuffer" + | "Atomics" + | "Proxy" + | "Reflect" + | "Intl" + | "WebAssembly" + | "require" + | "module" + | "exports" + | "__dirname" + | "__filename" + ) +} + impl StepTransform { fn process_stmt(&mut self, stmt: &mut Stmt) { match stmt { @@ -287,6 +702,14 @@ impl StepTransform { // Clone the function and remove the directive before hoisting let mut cloned_function = fn_decl.function.clone(); self.remove_use_step_directive(&mut cloned_function.body); + + // Collect closure variables + let closure_vars = + ClosureVariableCollector::collect_from_function( + &cloned_function, + &self.module_imports, + ); + let fn_expr = FnExpr { ident: Some(fn_decl.ident.clone()), function: cloned_function, @@ -295,6 +718,8 @@ impl StepTransform { fn_name.clone(), fn_expr, fn_decl.function.span, + closure_vars, + false, // Regular function, not arrow )); *stmt = Stmt::Empty(EmptyStmt { span: DUMMY_SP }); return; @@ -305,7 +730,15 @@ impl StepTransform { fn_decl.function.span, false, ); - let proxy_ref = self.create_step_proxy_reference(&step_id); + + // Collect closure variables + let closure_vars = + ClosureVariableCollector::collect_from_function( + &fn_decl.function, + &self.module_imports, + ); + let proxy_ref = + self.create_step_proxy_reference(&step_id, &closure_vars); let var_decl = Decl::Var(Box::new(VarDecl { span: DUMMY_SP, @@ -467,6 +900,7 @@ impl StepTransform { anonymous_fn_counter: 0, object_property_workflow_conversions: Vec::new(), current_var_context: None, + module_imports: HashSet::new(), } } @@ -498,12 +932,12 @@ impl StepTransform { fn generate_unique_name(&self, base_name: &str) -> String { let mut name = base_name.to_string(); let mut counter = 0; - + while self.declared_identifiers.contains(&name) { counter += 1; name = format!("{}${}", base_name, counter); } - + name } @@ -511,10 +945,27 @@ impl StepTransform { fn collect_declared_identifiers(&mut self, items: &[ModuleItem]) { for item in items { match item { - ModuleItem::Stmt(Stmt::Decl(decl)) => { - match decl { + ModuleItem::Stmt(Stmt::Decl(decl)) => match decl { + Decl::Fn(fn_decl) => { + self.declared_identifiers + .insert(fn_decl.ident.sym.to_string()); + } + Decl::Var(var_decl) => { + for declarator in &var_decl.decls { + self.collect_idents_from_pat(&declarator.name); + } + } + Decl::Class(class_decl) => { + self.declared_identifiers + .insert(class_decl.ident.sym.to_string()); + } + _ => {} + }, + ModuleItem::ModuleDecl(module_decl) => match module_decl { + ModuleDecl::ExportDecl(export_decl) => match &export_decl.decl { Decl::Fn(fn_decl) => { - self.declared_identifiers.insert(fn_decl.ident.sym.to_string()); + self.declared_identifiers + .insert(fn_decl.ident.sym.to_string()); } Decl::Var(var_decl) => { for declarator in &var_decl.decls { @@ -522,62 +973,44 @@ impl StepTransform { } } Decl::Class(class_decl) => { - self.declared_identifiers.insert(class_decl.ident.sym.to_string()); + self.declared_identifiers + .insert(class_decl.ident.sym.to_string()); } _ => {} - } - } - ModuleItem::ModuleDecl(module_decl) => { - match module_decl { - ModuleDecl::ExportDecl(export_decl) => { - match &export_decl.decl { - Decl::Fn(fn_decl) => { - self.declared_identifiers.insert(fn_decl.ident.sym.to_string()); - } - Decl::Var(var_decl) => { - for declarator in &var_decl.decls { - self.collect_idents_from_pat(&declarator.name); - } - } - Decl::Class(class_decl) => { - self.declared_identifiers.insert(class_decl.ident.sym.to_string()); - } - _ => {} + }, + ModuleDecl::ExportDefaultDecl(default_decl) => match &default_decl.decl { + DefaultDecl::Fn(fn_expr) => { + if let Some(ident) = &fn_expr.ident { + self.declared_identifiers.insert(ident.sym.to_string()); } } - ModuleDecl::ExportDefaultDecl(default_decl) => { - match &default_decl.decl { - DefaultDecl::Fn(fn_expr) => { - if let Some(ident) = &fn_expr.ident { - self.declared_identifiers.insert(ident.sym.to_string()); - } - } - DefaultDecl::Class(class_expr) => { - if let Some(ident) = &class_expr.ident { - self.declared_identifiers.insert(ident.sym.to_string()); - } - } - _ => {} + DefaultDecl::Class(class_expr) => { + if let Some(ident) = &class_expr.ident { + self.declared_identifiers.insert(ident.sym.to_string()); } } - ModuleDecl::Import(import_decl) => { - for specifier in &import_decl.specifiers { - match specifier { - ImportSpecifier::Named(named) => { - self.declared_identifiers.insert(named.local.sym.to_string()); - } - ImportSpecifier::Default(default) => { - self.declared_identifiers.insert(default.local.sym.to_string()); - } - ImportSpecifier::Namespace(namespace) => { - self.declared_identifiers.insert(namespace.local.sym.to_string()); - } + _ => {} + }, + ModuleDecl::Import(import_decl) => { + for specifier in &import_decl.specifiers { + match specifier { + ImportSpecifier::Named(named) => { + self.declared_identifiers + .insert(named.local.sym.to_string()); + } + ImportSpecifier::Default(default) => { + self.declared_identifiers + .insert(default.local.sym.to_string()); + } + ImportSpecifier::Namespace(namespace) => { + self.declared_identifiers + .insert(namespace.local.sym.to_string()); } } } - _ => {} } - } + _ => {} + }, _ => {} } } @@ -1309,11 +1742,73 @@ impl StepTransform { } } + // Convert a FnExpr back to ArrowExpr (for hoisting arrow functions) + fn convert_fn_expr_to_arrow(&self, fn_expr: &FnExpr) -> ArrowExpr { + let body = if let Some(block) = &fn_expr.function.body { + // Check if body is a single return statement - can be simplified to expression + if block.stmts.len() == 1 { + if let Stmt::Return(ret) = &block.stmts[0] { + if let Some(arg) = &ret.arg { + // Single return statement - use expression body + Box::new(BlockStmtOrExpr::Expr(arg.clone())) + } else { + // return with no value - keep as block + Box::new(BlockStmtOrExpr::BlockStmt(block.clone())) + } + } else { + Box::new(BlockStmtOrExpr::BlockStmt(block.clone())) + } + } else { + Box::new(BlockStmtOrExpr::BlockStmt(block.clone())) + } + } else { + Box::new(BlockStmtOrExpr::BlockStmt(BlockStmt { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + stmts: vec![], + })) + }; + + ArrowExpr { + span: fn_expr.function.span, + ctxt: SyntaxContext::empty(), + params: fn_expr + .function + .params + .iter() + .map(|p| p.pat.clone()) + .collect(), + body, + is_async: fn_expr.function.is_async, + is_generator: fn_expr.function.is_generator, + type_params: fn_expr.function.type_params.clone(), + return_type: fn_expr.function.return_type.clone(), + } + } + // Generate the import for registerStepFunction (step mode) - fn create_register_import(&self) -> ModuleItem { - ModuleItem::ModuleDecl(ModuleDecl::Import(ImportDecl { - span: DUMMY_SP, - specifiers: vec![ImportSpecifier::Named(ImportNamedSpecifier { + fn create_private_imports( + &self, + include_register: bool, + include_closure_vars: bool, + ) -> ModuleItem { + let mut specifiers = vec![]; + + if include_closure_vars { + specifiers.push(ImportSpecifier::Named(ImportNamedSpecifier { + span: DUMMY_SP, + local: Ident::new( + "__private_getClosureVars".into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + imported: None, + is_type_only: false, + })); + } + + if include_register { + specifiers.push(ImportSpecifier::Named(ImportNamedSpecifier { span: DUMMY_SP, local: Ident::new( "registerStepFunction".into(), @@ -1322,7 +1817,12 @@ impl StepTransform { ), imported: None, is_type_only: false, - })], + })); + } + + ModuleItem::ModuleDecl(ModuleDecl::Import(ImportDecl { + span: DUMMY_SP, + specifiers, src: Box::new(Str { span: DUMMY_SP, value: "workflow/internal/private".into(), @@ -1334,8 +1834,51 @@ impl StepTransform { })) } - // Create a proxy reference: globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step_id") (workflow mode) - fn create_step_proxy_reference(&self, step_id: &str) -> Expr { + // Create a proxy reference: globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step_id", closure_fn) (workflow mode) + fn create_step_proxy_reference(&self, step_id: &str, closure_vars: &[String]) -> Expr { + let mut args = vec![ExprOrSpread { + spread: None, + expr: Box::new(Expr::Lit(Lit::Str(Str { + span: DUMMY_SP, + value: step_id.into(), + raw: None, + }))), + }]; + + // If there are closure variables, add them as a second argument + if !closure_vars.is_empty() { + // Create arrow function: () => ({ var1, var2 }) + let closure_obj = Expr::Object(ObjectLit { + span: DUMMY_SP, + props: closure_vars + .iter() + .map(|var_name| { + PropOrSpread::Prop(Box::new(Prop::Shorthand(Ident::new( + var_name.clone().into(), + DUMMY_SP, + SyntaxContext::empty(), + )))) + }) + .collect(), + }); + + let closure_fn = Expr::Arrow(ArrowExpr { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + params: vec![], + body: Box::new(BlockStmtOrExpr::Expr(Box::new(closure_obj))), + is_async: false, + is_generator: false, + type_params: None, + return_type: None, + }); + + args.push(ExprOrSpread { + spread: None, + expr: Box::new(closure_fn), + }); + } + Expr::Call(CallExpr { span: DUMMY_SP, ctxt: SyntaxContext::empty(), @@ -1372,14 +1915,7 @@ impl StepTransform { })), }), }))), - args: vec![ExprOrSpread { - spread: None, - expr: Box::new(Expr::Lit(Lit::Str(Str { - span: DUMMY_SP, - value: step_id.into(), - raw: None, - }))), - }], + args, type_args: None, }) } @@ -1957,7 +2493,8 @@ impl StepTransform { .map(|fn_name| { // Check if this export name has a different const name (e.g., "default" -> "__default") let fn_name_str: &str = fn_name; - let actual_name = self.workflow_export_to_const_name + let actual_name = self + .workflow_export_to_const_name .get(fn_name_str) .map(|s| s.as_str()) .unwrap_or(fn_name_str); @@ -2202,11 +2739,22 @@ impl VisitMut for StepTransform { // No imports needed for workflow mode } TransformMode::Step => { - if !self.registration_calls.is_empty() + // Check what needs to be imported + let needs_register_import = !self.registration_calls.is_empty() || !self.object_property_step_functions.is_empty() - || !self.nested_step_functions.is_empty() - { - imports_to_add.push(self.create_register_import()); + || !self.nested_step_functions.is_empty(); + + // Check if any nested steps have closure variables + let needs_closure_import = self + .nested_step_functions + .iter() + .any(|(_, _, _, closure_vars, _)| !closure_vars.is_empty()); + + if needs_register_import || needs_closure_import { + imports_to_add.push(self.create_private_imports( + needs_register_import, + needs_closure_import, + )); } } TransformMode::Client => { @@ -2233,17 +2781,99 @@ impl VisitMut for StepTransform { // Process nested step functions FIRST (they typically appear earlier in source) let nested_functions: Vec<_> = self.nested_step_functions.drain(..).collect(); - for (fn_name, fn_expr, span) in nested_functions { - // Create a function declaration for the hoisted function - let hoisted_decl = ModuleItem::Stmt(Stmt::Decl(Decl::Fn(FnDecl { - ident: Ident::new( - fn_name.clone().into(), - DUMMY_SP, - SyntaxContext::empty(), - ), - function: fn_expr.function, - declare: false, - }))); + + for (fn_name, mut fn_expr, span, closure_vars, was_arrow) in nested_functions { + // If there are closure variables, add destructuring as first statement + if !closure_vars.is_empty() { + if let Some(body) = &mut fn_expr.function.body { + // Create destructuring statement: const { var1, var2 } = __private_getClosureVars(); + let closure_destructure = + Stmt::Decl(Decl::Var(Box::new(VarDecl { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + kind: VarDeclKind::Const, + decls: vec![VarDeclarator { + span: DUMMY_SP, + name: Pat::Object(ObjectPat { + span: DUMMY_SP, + props: closure_vars + .iter() + .map(|var_name| { + ObjectPatProp::Assign(AssignPatProp { + span: DUMMY_SP, + key: BindingIdent { + id: Ident::new( + var_name.clone().into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + type_ann: None, + }, + value: None, + }) + }) + .collect(), + optional: false, + type_ann: None, + }), + init: Some(Box::new(Expr::Call(CallExpr { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + callee: Callee::Expr(Box::new(Expr::Ident( + Ident::new( + "__private_getClosureVars".into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + ))), + args: vec![], + type_args: None, + }))), + definite: false, + }], + declare: false, + }))); + + // Prepend to function body + body.stmts.insert(0, closure_destructure); + } + } + + // Create the appropriate hoisted declaration based on original function type + let hoisted_decl = if was_arrow { + // Convert back to arrow function: var name = async () => { ... }; + let arrow_expr = self.convert_fn_expr_to_arrow(&fn_expr); + ModuleItem::Stmt(Stmt::Decl(Decl::Var(Box::new(VarDecl { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + kind: VarDeclKind::Var, + decls: vec![VarDeclarator { + span: DUMMY_SP, + name: Pat::Ident(BindingIdent { + id: Ident::new( + fn_name.clone().into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + type_ann: None, + }), + init: Some(Box::new(Expr::Arrow(arrow_expr))), + definite: false, + }], + declare: false, + })))) + } else { + // Keep as function declaration: async function name() { ... } + ModuleItem::Stmt(Stmt::Decl(Decl::Fn(FnDecl { + ident: Ident::new( + fn_name.clone().into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + function: fn_expr.function, + declare: false, + }))) + }; // Insert at current position and increment for next iteration module.body.insert(current_insert_pos, hoisted_decl); @@ -2415,7 +3045,7 @@ impl VisitMut for StepTransform { } TransformMode::Step => { if !self.registration_calls.is_empty() { - module_items.push(self.create_register_import()); + module_items.push(self.create_private_imports(true, false)); } } TransformMode::Client => { @@ -2586,7 +3216,26 @@ impl VisitMut for StepTransform { fn visit_mut_module_items(&mut self, items: &mut Vec) { // Collect all declared identifiers to avoid naming collisions self.collect_declared_identifiers(items); - + + // Collect module-level imports first + for item in items.iter() { + if let ModuleItem::ModuleDecl(ModuleDecl::Import(import_decl)) = item { + for specifier in &import_decl.specifiers { + match specifier { + ImportSpecifier::Named(named) => { + self.module_imports.insert(named.local.sym.to_string()); + } + ImportSpecifier::Default(default) => { + self.module_imports.insert(default.local.sym.to_string()); + } + ImportSpecifier::Namespace(namespace) => { + self.module_imports.insert(namespace.local.sym.to_string()); + } + } + } + } + } + // Check for file-level directives self.has_file_step_directive = self.check_module_directive(items); self.has_file_workflow_directive = self.check_module_workflow_directive(items); @@ -2922,8 +3571,9 @@ impl VisitMut for StepTransform { // Handle default workflow exports (workflow and client modes) // We need to: 1) find the export default position, 2) replace it with const declaration, // 3) add workflowId assignment, 4) add export default at the end - if (self.mode == TransformMode::Workflow || self.mode == TransformMode::Client) - && !self.default_workflow_exports.is_empty() { + if (self.mode == TransformMode::Workflow || self.mode == TransformMode::Client) + && !self.default_workflow_exports.is_empty() + { let default_workflows: Vec<_> = self.default_workflow_exports.drain(..).collect(); let default_exports: Vec<_> = self.default_exports_to_replace.drain(..).collect(); @@ -2931,8 +3581,8 @@ impl VisitMut for StepTransform { let mut export_position = None; for (i, item) in items.iter().enumerate() { match item { - ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultExpr(_)) | - ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultDecl(_)) => { + ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultExpr(_)) + | ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultDecl(_)) => { export_position = Some(i); break; } @@ -2947,35 +3597,46 @@ impl VisitMut for StepTransform { // Insert in correct order: const, workflowId, export default for (const_name, fn_expr, span) in default_workflows { // Insert const declaration at the original export position - items.insert(pos, ModuleItem::Stmt(Stmt::Decl(Decl::Var(Box::new(VarDecl { - span: DUMMY_SP, - ctxt: SyntaxContext::empty(), - kind: VarDeclKind::Const, - declare: false, - decls: vec![VarDeclarator { + items.insert( + pos, + ModuleItem::Stmt(Stmt::Decl(Decl::Var(Box::new(VarDecl { span: DUMMY_SP, - name: Pat::Ident(BindingIdent { - id: Ident::new(const_name.clone().into(), DUMMY_SP, SyntaxContext::empty()), - type_ann: None, - }), - init: Some(Box::new(fn_expr)), - definite: false, - }], - }))))); - + ctxt: SyntaxContext::empty(), + kind: VarDeclKind::Const, + declare: false, + decls: vec![VarDeclarator { + span: DUMMY_SP, + name: Pat::Ident(BindingIdent { + id: Ident::new( + const_name.clone().into(), + DUMMY_SP, + SyntaxContext::empty(), + ), + type_ann: None, + }), + init: Some(Box::new(fn_expr)), + definite: false, + }], + })))), + ); + // Insert workflowId assignment after const - items.insert(pos + 1, ModuleItem::Stmt( - self.create_workflow_id_assignment(&const_name, span), - )); + items.insert( + pos + 1, + ModuleItem::Stmt(self.create_workflow_id_assignment(&const_name, span)), + ); // Insert export default at the end (after workflowId) for (_export_name, replacement_expr) in &default_exports { - items.insert(pos + 2, ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultExpr( - ExportDefaultExpr { - span: DUMMY_SP, - expr: Box::new(replacement_expr.clone()), - }, - ))); + items.insert( + pos + 2, + ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultExpr( + ExportDefaultExpr { + span: DUMMY_SP, + expr: Box::new(replacement_expr.clone()), + }, + )), + ); } } } @@ -3928,6 +4589,9 @@ impl VisitMut for StepTransform { &mut cloned_arrow.body, ); + // Collect closure variables before conversion + let closure_vars = ClosureVariableCollector::collect_from_arrow_expr(&cloned_arrow, &self.module_imports); + // Create a function expression from the arrow function // (We need to convert it to a regular function for hoisting) let fn_expr = FnExpr { @@ -3981,6 +4645,8 @@ impl VisitMut for StepTransform { name.clone(), fn_expr, arrow_expr.span, + closure_vars, + true, // Was an arrow function )); // Mark the entire var declarator for removal by nulling out the init @@ -3995,9 +4661,13 @@ impl VisitMut for StepTransform { arrow_expr.span, false, ); - *init = Box::new( - self.create_step_proxy_reference(&step_id), - ); + + // Collect closure variables + let closure_vars = ClosureVariableCollector::collect_from_arrow_expr(&arrow_expr, &self.module_imports); + *init = Box::new(self.create_step_proxy_reference( + &step_id, + &closure_vars, + )); } TransformMode::Client => { // In client mode, remove the nested step @@ -4267,14 +4937,16 @@ impl VisitMut for StepTransform { let const_name = if fn_name == "default" { // Anonymous: generate unique name let unique_name = self.generate_unique_name("__default"); - self.workflow_export_to_const_name.insert("default".to_string(), unique_name.clone()); + self.workflow_export_to_const_name + .insert("default".to_string(), unique_name.clone()); unique_name } else { // Named: use the function name - self.workflow_export_to_const_name.insert("default".to_string(), fn_name.clone()); + self.workflow_export_to_const_name + .insert("default".to_string(), fn_name.clone()); fn_name.clone() }; - + // Always use "default" as the metadata key for default exports self.workflow_function_names.insert("default".to_string()); @@ -4294,7 +4966,7 @@ impl VisitMut for StepTransform { Expr::Fn(fn_expr.clone()), fn_expr.function.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( fn_name.clone(), @@ -4318,7 +4990,7 @@ impl VisitMut for StepTransform { TransformMode::Client => { // In client mode, replace workflow function body with error throw self.remove_use_workflow_directive(&mut fn_expr.function.body); - + let error_msg = format!( "You attempted to execute workflow {} function directly. To start a workflow, use start({}) from workflow/api", const_name, const_name @@ -4347,7 +5019,7 @@ impl VisitMut for StepTransform { arg: Box::new(error_expr), })]; } - + // For anonymous functions, convert to const declaration so we can assign workflowId if fn_name == "default" { // Track for const declaration and workflowId assignment @@ -4356,7 +5028,7 @@ impl VisitMut for StepTransform { Expr::Fn(fn_expr.clone()), fn_expr.function.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( fn_name.clone(), @@ -4445,8 +5117,9 @@ impl VisitMut for StepTransform { // Generate unique name first so we can use it in workflow_function_names let unique_name = self.generate_unique_name("__default"); // For function expression default exports, track mapping from "default" to actual const name - self.workflow_export_to_const_name.insert("default".to_string(), unique_name.clone()); - + self.workflow_export_to_const_name + .insert("default".to_string(), unique_name.clone()); + // Always use "default" as the metadata key for default exports self.workflow_function_names.insert("default".to_string()); @@ -4457,14 +5130,14 @@ impl VisitMut for StepTransform { TransformMode::Workflow => { // In workflow mode, convert to const declaration self.remove_use_workflow_directive(&mut fn_expr.function.body); - + // Track for const declaration and workflowId assignment self.default_workflow_exports.push(( unique_name.clone(), Expr::Fn(fn_expr.clone()), fn_expr.function.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( "default".to_string(), @@ -4506,14 +5179,14 @@ impl VisitMut for StepTransform { arg: Box::new(error_expr), })]; } - + // Track for const declaration and workflowId assignment self.default_workflow_exports.push(( unique_name.clone(), Expr::Fn(fn_expr.clone()), fn_expr.function.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( "default".to_string(), @@ -4545,11 +5218,12 @@ impl VisitMut for StepTransform { } else { // For arrow function default exports, generate unique name and track mapping let unique_name = self.generate_unique_name("__default"); - self.workflow_export_to_const_name.insert("default".to_string(), unique_name.clone()); - + self.workflow_export_to_const_name + .insert("default".to_string(), unique_name.clone()); + // Always use "default" as the metadata key for default exports self.workflow_function_names.insert("default".to_string()); - + match self.mode { TransformMode::Step => { // Workflow functions are not processed in step mode @@ -4557,14 +5231,14 @@ impl VisitMut for StepTransform { TransformMode::Workflow => { // In workflow mode, convert to const declaration self.remove_use_workflow_directive_arrow(&mut arrow_expr.body); - + // Track for const declaration and workflowId assignment self.default_workflow_exports.push(( unique_name.clone(), Expr::Arrow(arrow_expr.clone()), arrow_expr.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( "default".to_string(), @@ -4609,14 +5283,14 @@ impl VisitMut for StepTransform { arg: Box::new(error_expr), })], })); - + // Track for const declaration and workflowId assignment self.default_workflow_exports.push(( unique_name.clone(), Expr::Arrow(arrow_expr.clone()), arrow_expr.span, )); - + // Track for replacement with identifier self.default_exports_to_replace.push(( "default".to_string(), @@ -4710,6 +5384,9 @@ impl VisitMut for StepTransform { &mut cloned_arrow.body, ); + // Collect closure variables + let closure_vars = ClosureVariableCollector::collect_from_arrow_expr(&cloned_arrow, &self.module_imports); + // Convert to function expression let fn_expr = FnExpr { ident: Some(Ident::new( @@ -4769,6 +5446,8 @@ impl VisitMut for StepTransform { generated_name.clone(), fn_expr, arrow_expr.span, + closure_vars, + true, // Was an arrow function )); // Replace with identifier reference @@ -4788,8 +5467,14 @@ impl VisitMut for StepTransform { arrow_expr.span, false, ); + + // Collect closure variables + let closure_vars = ClosureVariableCollector::collect_from_arrow_expr(&arrow_expr, &self.module_imports); *kv_prop.value = self - .create_step_proxy_reference(&step_id); + .create_step_proxy_reference( + &step_id, + &closure_vars, + ); } TransformMode::Client => { // Just remove directive @@ -4826,6 +5511,9 @@ impl VisitMut for StepTransform { &mut cloned_fn.function.body, ); + // Collect closure variables + let closure_vars = ClosureVariableCollector::collect_from_function(&*cloned_fn.function, &self.module_imports); + let hoisted_fn_expr = FnExpr { ident: Some(Ident::new( generated_name.clone().into(), @@ -4839,6 +5527,8 @@ impl VisitMut for StepTransform { generated_name.clone(), hoisted_fn_expr, fn_expr.function.span, + closure_vars, + false, // Was a function expression )); // Replace with identifier reference @@ -4858,8 +5548,14 @@ impl VisitMut for StepTransform { fn_expr.function.span, false, ); + + // Collect closure variables + let closure_vars = ClosureVariableCollector::collect_from_function(&fn_expr.function, &self.module_imports); *kv_prop.value = self - .create_step_proxy_reference(&step_id); + .create_step_proxy_reference( + &step_id, + &closure_vars, + ); } TransformMode::Client => { // Just remove directive @@ -4899,6 +5595,13 @@ impl VisitMut for StepTransform { &mut cloned_function.body, ); + // Collect closure variables + let closure_vars = + ClosureVariableCollector::collect_from_function( + &cloned_function, + &self.module_imports, + ); + let fn_expr = FnExpr { ident: Some(Ident::new( generated_name.clone().into(), @@ -4912,6 +5615,8 @@ impl VisitMut for StepTransform { generated_name.clone(), fn_expr, method_prop.function.span, + closure_vars, + false, // Was a method )); // Replace method with property pointing to identifier @@ -4936,6 +5641,13 @@ impl VisitMut for StepTransform { false, ); + // Collect closure variables + let closure_vars = + ClosureVariableCollector::collect_from_function( + &method_prop.function, + &self.module_imports, + ); + // Replace method with property pointing to proxy *boxed_prop = Box::new(Prop::KeyValue(KeyValueProp { @@ -4943,6 +5655,7 @@ impl VisitMut for StepTransform { value: Box::new( self.create_step_proxy_reference( &step_id, + &closure_vars, ), ), })); diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-in-workflow/output-step.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-in-workflow/output-step.js index 619f89ee5..9022b31a7 100644 --- a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-in-workflow/output-step.js +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-in-workflow/output-step.js @@ -4,15 +4,9 @@ import { registerStepFunction } from "workflow/internal/private"; async function step(a, b) { return a + b; } -async function arrowStep(x, y) { - return x * y; -} -async function letArrowStep(x, y) { - return x - y; -} -async function varArrowStep(x, y) { - return x / y; -} +var arrowStep = async (x, y)=>x * y; +var letArrowStep = async (x, y)=>x - y; +var varArrowStep = async (x, y)=>x / y; var helpers$objectStep = async (x, y)=>{ return x + y + 10; }; diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/input.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/input.js new file mode 100644 index 000000000..d209a5188 --- /dev/null +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/input.js @@ -0,0 +1,30 @@ +import { DurableAgent } from '@workflow/ai/agent'; +import { gateway } from 'ai'; + +export async function wflow() { + 'use workflow'; + let count = 42; + + async function namedStepWithClosureVars() { + 'use step'; + console.log('count', count); + } + + const agent = new DurableAgent({ + arrowFunctionWithClosureVars: async () => { + 'use step'; + console.log('count', count); + return gateway('openai/gpt-5'); + }, + + namedFunctionWithClosureVars: async function() { + 'use step'; + console.log('count', count); + }, + + async methodWithClosureVars() { + 'use step'; + console.log('count', count); + }, + }); +} diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-client.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-client.js new file mode 100644 index 000000000..043e0ec09 --- /dev/null +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-client.js @@ -0,0 +1,5 @@ +/**__internal_workflows{"workflows":{"input.js":{"wflow":{"workflowId":"workflow//input.js//wflow"}}}}*/; +export async function wflow() { + throw new Error("You attempted to execute workflow wflow function directly. To start a workflow, use start(wflow) from workflow/api"); +} +wflow.workflowId = "workflow//input.js//wflow"; diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-step.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-step.js new file mode 100644 index 000000000..82c9b1509 --- /dev/null +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-step.js @@ -0,0 +1,33 @@ +import { __private_getClosureVars, registerStepFunction } from "workflow/internal/private"; +import { DurableAgent } from '@workflow/ai/agent'; +import { gateway } from 'ai'; +/**__internal_workflows{"workflows":{"input.js":{"wflow":{"workflowId":"workflow//input.js//wflow"}}},"steps":{"input.js":{"_anonymousStep0":{"stepId":"step//input.js//_anonymousStep0"},"_anonymousStep1":{"stepId":"step//input.js//_anonymousStep1"},"_anonymousStep2":{"stepId":"step//input.js//_anonymousStep2"},"namedStepWithClosureVars":{"stepId":"step//input.js//namedStepWithClosureVars"}}}}*/; +async function namedStepWithClosureVars() { + const { count } = __private_getClosureVars(); + console.log('count', count); +} +var _anonymousStep0 = async ()=>{ + const { count } = __private_getClosureVars(); + console.log('count', count); + return gateway('openai/gpt-5'); +}; +async function _anonymousStep1() { + const { count } = __private_getClosureVars(); + console.log('count', count); +} +async function _anonymousStep2() { + const { count } = __private_getClosureVars(); + console.log('count', count); +} +export async function wflow() { + let count = 42; + const agent = new DurableAgent({ + arrowFunctionWithClosureVars: _anonymousStep0, + namedFunctionWithClosureVars: _anonymousStep1, + methodWithClosureVars: _anonymousStep2 + }); +} +registerStepFunction("step//input.js//namedStepWithClosureVars", namedStepWithClosureVars); +registerStepFunction("step//input.js//_anonymousStep0", _anonymousStep0); +registerStepFunction("step//input.js//_anonymousStep1", _anonymousStep1); +registerStepFunction("step//input.js//_anonymousStep2", _anonymousStep2); diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-workflow.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-workflow.js new file mode 100644 index 000000000..f9e6391de --- /dev/null +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-step-with-closure/output-workflow.js @@ -0,0 +1,20 @@ +import { DurableAgent } from '@workflow/ai/agent'; +/**__internal_workflows{"workflows":{"input.js":{"wflow":{"workflowId":"workflow//input.js//wflow"}}},"steps":{"input.js":{"_anonymousStep0":{"stepId":"step//input.js//_anonymousStep0"},"_anonymousStep1":{"stepId":"step//input.js//_anonymousStep1"},"_anonymousStep2":{"stepId":"step//input.js//_anonymousStep2"},"namedStepWithClosureVars":{"stepId":"step//input.js//namedStepWithClosureVars"}}}}*/; +export async function wflow() { + let count = 42; + var namedStepWithClosureVars = globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step//input.js//namedStepWithClosureVars", ()=>({ + count + })); + const agent = new DurableAgent({ + arrowFunctionWithClosureVars: globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step//input.js//_anonymousStep0", ()=>({ + count + })), + namedFunctionWithClosureVars: globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step//input.js//_anonymousStep1", ()=>({ + count + })), + methodWithClosureVars: globalThis[Symbol.for("WORKFLOW_USE_STEP")]("step//input.js//_anonymousStep2", ()=>({ + count + })) + }); +} +wflow.workflowId = "workflow//input.js//wflow"; diff --git a/packages/swc-plugin-workflow/transform/tests/fixture/nested-steps-in-object-constructor/output-step.js b/packages/swc-plugin-workflow/transform/tests/fixture/nested-steps-in-object-constructor/output-step.js index 415f8a45e..b1cd46531 100644 --- a/packages/swc-plugin-workflow/transform/tests/fixture/nested-steps-in-object-constructor/output-step.js +++ b/packages/swc-plugin-workflow/transform/tests/fixture/nested-steps-in-object-constructor/output-step.js @@ -3,12 +3,8 @@ import { DurableAgent } from '@workflow/ai/agent'; import { gateway, tool } from 'ai'; import * as z from 'zod'; /**__internal_workflows{"workflows":{"input.js":{"test":{"workflowId":"workflow//input.js//test"}}},"steps":{"input.js":{"_anonymousStep0":{"stepId":"step//input.js//_anonymousStep0"},"_anonymousStep1":{"stepId":"step//input.js//_anonymousStep1"}}}}*/; -async function _anonymousStep0() { - return gateway('openai/gpt-5'); -} -async function _anonymousStep1({ location }) { - return `Weather in ${location}: Sunny, 72°F`; -} +var _anonymousStep0 = async ()=>gateway('openai/gpt-5'); +var _anonymousStep1 = async ({ location })=>`Weather in ${location}: Sunny, 72°F`; export async function test() { 'use workflow'; const agent = new DurableAgent({ diff --git a/packages/web-shared/src/sidebar/attribute-panel.tsx b/packages/web-shared/src/sidebar/attribute-panel.tsx index 0786276ea..adb110f9c 100644 --- a/packages/web-shared/src/sidebar/attribute-panel.tsx +++ b/packages/web-shared/src/sidebar/attribute-panel.tsx @@ -105,6 +105,36 @@ const attributeToDisplayFn: Record< // Resolved attributes, won't actually use this function metadata: JsonBlock, input: (value: unknown) => { + // Check if input has args + closure vars structure + if (value && typeof value === 'object' && 'args' in value) { + const { args, closureVars } = value as { + args: unknown[]; + closureVars?: Record; + }; + const argCount = Array.isArray(args) ? args.length : 0; + const hasClosureVars = closureVars && Object.keys(closureVars).length > 0; + + return ( + <> + + {Array.isArray(args) + ? args.map((v, i) => ( +
+ {JsonBlock(v)} +
+ )) + : JsonBlock(args)} +
+ {hasClosureVars && ( + + {JsonBlock(closureVars)} + + )} + + ); + } + + // Fallback: treat as plain array or object const argCount = Array.isArray(value) ? value.length : 0; return ( diff --git a/packages/world/src/steps.ts b/packages/world/src/steps.ts index 8a51fd4ae..8c973f6b9 100644 --- a/packages/world/src/steps.ts +++ b/packages/world/src/steps.ts @@ -41,7 +41,7 @@ export type Step = z.infer; export interface CreateStepRequest { stepId: string; stepName: string; - input: SerializedData[]; + input: SerializedData; } export interface UpdateStepRequest { diff --git a/workbench/example/workflows/99_e2e.ts b/workbench/example/workflows/99_e2e.ts index bd4f1faa7..7e442e52e 100644 --- a/workbench/example/workflows/99_e2e.ts +++ b/workbench/example/workflows/99_e2e.ts @@ -512,3 +512,21 @@ async function doubleNumber(x: number) { 'use step'; return x * 2; } + +////////////////////////////////////////////////////////// + +export async function closureVariableWorkflow(baseValue: number) { + 'use workflow'; + let multiplier = 3; + const prefix = 'Result: '; + + // Nested step function that uses closure variables + const calculate = async () => { + 'use step'; + const result = baseValue * multiplier; + return `${prefix}${result}`; + }; + + const output = await calculate(); + return output; +}