diff --git a/packages/typegpu/src/resolutionCtx.ts b/packages/typegpu/src/resolutionCtx.ts index abf6e26cbd..a75af493e7 100644 --- a/packages/typegpu/src/resolutionCtx.ts +++ b/packages/typegpu/src/resolutionCtx.ts @@ -163,6 +163,7 @@ class ItemStateStackImpl implements ItemStateStack { this._stack.push({ type: 'blockScope', declarations: new Map(), + externals: new Map(), }); } @@ -232,7 +233,8 @@ class ItemStateStackImpl implements ItemStateStack { } if (layer?.type === 'blockScope') { - const snippet = layer.declarations.get(id); + // the order matters + const snippet = layer.declarations.get(id) ?? layer.externals.get(id); if (snippet !== undefined) { return snippet; } @@ -260,6 +262,30 @@ class ItemStateStackImpl implements ItemStateStack { throw new Error('No block scope found to define a variable in.'); } + + setBlockExternals(externals: Record) { + for (let i = this._stack.length - 1; i >= 0; --i) { + const layer = this._stack[i]; + if (layer?.type === 'blockScope') { + Object.entries(externals).forEach(([id, snippet]) => { + layer.externals.set(id, snippet); + }); + return; + } + } + throw new Error('No block scope found to set externals in.'); + } + + clearBlockExternals() { + for (let i = this._stack.length - 1; i >= 0; --i) { + const layer = this._stack[i]; + if (layer?.type === 'blockScope') { + layer.externals.clear(); + return; + } + } + throw new Error('No block scope found to clear externals in.'); + } } const INDENT = [ @@ -429,6 +455,14 @@ export class ResolutionCtxImpl implements ResolutionCtx { this._itemStateStack.pop('blockScope'); } + setBlockExternals(externals: Record) { + this._itemStateStack.setBlockExternals(externals); + } + + clearBlockExternals() { + this._itemStateStack.clearBlockExternals(); + } + generateLog(op: string, args: Snippet[]): Snippet { return this.#logGenerator.generateLog(this, op, args); } diff --git a/packages/typegpu/src/tgsl/generationHelpers.ts b/packages/typegpu/src/tgsl/generationHelpers.ts index d655b02f35..94b7a33489 100644 --- a/packages/typegpu/src/tgsl/generationHelpers.ts +++ b/packages/typegpu/src/tgsl/generationHelpers.ts @@ -92,6 +92,8 @@ export type GenerationCtx = ResolutionCtx & { generateLog(op: string, args: Snippet[]): Snippet; getById(id: string): Snippet | null; defineVariable(id: string, snippet: Snippet): void; + setBlockExternals(externals: Record): void; + clearBlockExternals(): void; /** * Types that are used in `return` statements are diff --git a/packages/typegpu/src/tgsl/shaderGenerator.ts b/packages/typegpu/src/tgsl/shaderGenerator.ts index 5c467459d2..90531d294a 100644 --- a/packages/typegpu/src/tgsl/shaderGenerator.ts +++ b/packages/typegpu/src/tgsl/shaderGenerator.ts @@ -2,10 +2,11 @@ import type { Block, Expression, Statement } from 'tinyest'; import type { Snippet } from '../data/snippet.ts'; import type { GenerationCtx } from './generationHelpers.ts'; import type { BaseData } from '../data/wgslTypes.ts'; +import type { ExternalMap } from '../core/resolve/externals.ts'; export interface ShaderGenerator { initGenerator(ctx: GenerationCtx): void; - block(body: Block): string; + block(body: Block, externalMap?: ExternalMap): string; identifier(id: string): Snippet; typedExpression(expression: Expression, expectedType: BaseData): Snippet; expression(expression: Expression): Snippet; diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index c6031d5149..b16d788740 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -36,6 +36,7 @@ import { } from './conversion.ts'; import { ArrayExpression, + coerceToSnippet, concretize, type GenerationCtx, numericLiteralToSnippet, @@ -51,6 +52,7 @@ import type { AnyFn } from '../core/function/fnTypes.ts'; import { arrayLength } from '../std/array.ts'; import { AutoStruct } from '../data/autoStruct.ts'; import { mathToStd } from './math.ts'; +import type { ExternalMap } from '../core/resolve/externals.ts'; const { NodeTypeCatalog: NODE } = tinyest; @@ -200,8 +202,19 @@ class WgslGenerator implements ShaderGenerator { public block( [_, statements]: tinyest.Block, + externalMap?: ExternalMap, ): string { this.ctx.pushBlockScope(); + + if (externalMap) { + const externals = Object.fromEntries( + Object.entries(externalMap).map(( + [id, value], + ) => [id, coerceToSnippet(value)]), + ); + this.ctx.setBlockExternals(externals); + } + try { this.ctx.indent(); const body = statements.map((statement) => this.statement(statement)) @@ -1233,8 +1246,6 @@ ${this.ctx.pre}else ${alternate}`; // If it's ephemeral, it's a value that cannot change. If it's a reference, we take // an implicit pointer to it let loopVarKind = 'let'; - const loopVarName = this.ctx.makeNameValid(loopVar[1]); - if (!isEphemeralSnippet(elementSnippet)) { if (elementSnippet.origin === 'constant-tgpu-const-ref') { loopVarKind = 'const'; @@ -1260,21 +1271,13 @@ ${this.ctx.pre}else ${alternate}`; } } - const loopVarSnippet = snip( - loopVarName, - elementType, - elementSnippet.origin, - ); - - this.ctx.defineVariable(loopVar[1], loopVarSnippet); - const forStr = stitch`${this.ctx.pre}for (var ${index} = 0u; ${index} < ${ tryConvertSnippet(this.ctx, elementCountSnippet, u32, false) }; ${index}++) {`; - this.ctx.indent(); + const loopVarName = this.ctx.makeNameValid(loopVar[1]); const loopVarDeclStr = stitch`${this.ctx.pre}${loopVarKind} ${loopVarName} = ${ tryConvertSnippet( @@ -1286,7 +1289,9 @@ ${this.ctx.pre}else ${alternate}`; };`; const bodyStr = `${this.ctx.pre}${ - this.block(blockifySingleStatement(body)) + this.block(blockifySingleStatement(body), { + [loopVar[1]]: snip(loopVarName, elementType, elementSnippet.origin), + }) }`; this.ctx.dedent(); diff --git a/packages/typegpu/src/types.ts b/packages/typegpu/src/types.ts index ba6e7ca7a5..1125a50c03 100644 --- a/packages/typegpu/src/types.ts +++ b/packages/typegpu/src/types.ts @@ -124,6 +124,7 @@ export type SlotBindingLayer = { export type BlockScopeLayer = { type: 'blockScope'; declarations: Map; + externals: Map; }; export type StackLayer = @@ -151,6 +152,8 @@ export interface ItemStateStack { externalMap: Record, ): FunctionScopeLayer; pushBlockScope(): void; + setBlockExternals(externals: Record): void; + clearBlockExternals(): void; pop(type: T): Extract; pop(): StackLayer | undefined; diff --git a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts index 90a2e895b3..fe34208d2b 100644 --- a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts +++ b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts @@ -1740,4 +1740,109 @@ describe('wgslGenerator', () => { }" `); }); + + it('block externals do not override identifiers', () => { + const f = () => { + 'use gpu'; + const y = 100; + const x = y; + return x; + }; + + const parsed = getMetaData(f)?.ast?.body as tinyest.Block; + + provideCtx(ctx, () => { + ctx[$internal].itemStateStack.pushFunctionScope( + 'normal', + [], + {}, + d.u32, + {}, + ); + + const res = wgslGenerator.block( + parsed, + { x: 42 }, + ); + + expect(res).toMatchInlineSnapshot(` + "{ + const y = 100; + const x = y; + return u32(x); + }" + `); + }); + }); + + it('block externals are injected correctly', () => { + const f = () => { + 'use gpu'; + for (const x of []) { + const y = x; + } + }; + + const parsed = getMetaData(f)?.ast?.body as tinyest.Block; + + provideCtx(ctx, () => { + ctx[$internal].itemStateStack.pushFunctionScope( + 'normal', + [], + {}, + d.Void, + {}, + ); + + const res = wgslGenerator.block( + (parsed[1][0] as tinyest.ForOf)[3] as tinyest.Block, + { x: 67 }, + ); + + expect(res).toMatchInlineSnapshot(` + "{ + const y = 67; + }" + `); + }); + }); + + it('block externals are respected in nested blocks', () => { + const f = () => { + 'use gpu'; + let result = d.i32(0); + const list = d.arrayOf(d.i32, 3)([1, 2, 3]); + for (const elem of list) { + { + // We use the `elem` in a nested block + result += elem; + } + } + }; + + const parsed = getMetaData(f)?.ast?.body as tinyest.Block; + + provideCtx(ctx, () => { + ctx[$internal].itemStateStack.pushFunctionScope( + 'normal', + [], + {}, + d.Void, + {}, + ); + + const res = wgslGenerator.block( + (parsed[1][2] as tinyest.ForOf)[3] as tinyest.Block, + { result: snip('result', d.i32, 'function'), elem: 7 }, + ); + + expect(res).toMatchInlineSnapshot(` + "{ + { + result += 7i; + } + }" + `); + }); + }); });