From 3095c42f3b0cb36116ae52c0cb08cc0f3ca91dde Mon Sep 17 00:00:00 2001 From: Konrad Reczko Date: Fri, 30 Jan 2026 15:11:45 +0100 Subject: [PATCH 1/3] refine std function and conversion handling --- .../src/core/function/createCallableSchema.ts | 89 ++++++ .../typegpu/src/core/function/dualImpl.ts | 10 +- packages/typegpu/src/data/matrix.ts | 4 +- packages/typegpu/src/data/numeric.ts | 25 +- packages/typegpu/src/data/vector.ts | 5 +- packages/typegpu/src/errors.ts | 17 ++ packages/typegpu/src/std/index.ts | 1 + packages/typegpu/src/std/numeric.ts | 280 +++++++++++------- packages/typegpu/src/std/operators.ts | 118 +++++--- packages/typegpu/src/std/texture.ts | 132 ++++++++- packages/typegpu/src/tgsl/conversion.ts | 71 +++-- .../typegpu/src/tgsl/generationHelpers.ts | 18 +- packages/typegpu/src/tgsl/wgslGenerator.ts | 4 +- packages/typegpu/src/types.ts | 2 +- .../tests/examples/individual/disco.test.ts | 14 +- .../examples/individual/image-tuning.test.ts | 2 +- .../examples/individual/jelly-slider.test.ts | 4 +- .../examples/individual/jelly-switch.test.ts | 2 +- .../individual/jump-flood-distance.test.ts | 2 +- .../individual/tgsl-parsing-test.test.ts | 4 +- packages/typegpu/tests/indent.test.ts | 4 +- packages/typegpu/tests/primitiveCast.test.ts | 2 +- .../tests/std/texture/textureGather.test.ts | 109 +++++++ .../typegpu/tests/tgsl/typeInference.test.ts | 10 +- .../typegpu/tests/tgsl/wgslGenerator.test.ts | 4 +- packages/typegpu/tests/vector.test.ts | 2 +- 26 files changed, 704 insertions(+), 231 deletions(-) create mode 100644 packages/typegpu/src/core/function/createCallableSchema.ts create mode 100644 packages/typegpu/tests/std/texture/textureGather.test.ts diff --git a/packages/typegpu/src/core/function/createCallableSchema.ts b/packages/typegpu/src/core/function/createCallableSchema.ts new file mode 100644 index 0000000000..7266c7ac6e --- /dev/null +++ b/packages/typegpu/src/core/function/createCallableSchema.ts @@ -0,0 +1,89 @@ +import { type MapValueToSnippet, snip } from '../../data/snippet.ts'; +import { type BaseData, isPtr } from '../../data/wgslTypes.ts'; +import { setName } from '../../shared/meta.ts'; +import { $gpuCallable } from '../../shared/symbols.ts'; +import { tryConvertSnippet } from '../../tgsl/conversion.ts'; +import { + type DualFn, + isKnownAtComptime, + NormalState, + type ResolutionCtx, +} from '../../types.ts'; + +type MapValueToDataType = { [K in keyof T]: BaseData }; +type AnyFn = (...args: never[]) => unknown; + +interface CallableSchemaOptions { + readonly name: string; + readonly normalImpl: T; + readonly codegenImpl: ( + ctx: ResolutionCtx, + args: MapValueToSnippet>, + ) => string; + readonly signature: ( + ...inArgTypes: MapValueToDataType> + ) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData }; +} + +export function callableSchema( + options: CallableSchemaOptions, +): DualFn { + const impl = ((...args: Parameters) => { + return options.normalImpl(...args); + }) as DualFn; + + setName(impl, options.name); + impl.toString = () => options.name; + impl[$gpuCallable] = { + get strictSignature() { + return undefined; + }, + call(ctx, args) { + const { argTypes, returnType } = options.signature( + ...args.map((s) => { + // Dereference implicit pointers + if (isPtr(s.dataType) && s.dataType.implicit) { + return s.dataType.inner; + } + return s.dataType; + }) as MapValueToDataType>, + ); + + const converted = args.map((s, idx) => { + const argType = argTypes[idx]; + if (!argType) { + throw new Error('Function called with invalid arguments'); + } + return tryConvertSnippet( + ctx, + s, + argType, + false, + ); + }) as MapValueToSnippet>; + + if (converted.every((s) => isKnownAtComptime(s))) { + ctx.pushMode(new NormalState()); + try { + return snip( + options.normalImpl(...converted.map((s) => s.value) as never[]), + returnType, + // Functions give up ownership of their return value + /* origin */ 'constant', + ); + } finally { + ctx.popMode('normal'); + } + } + + return snip( + options.codegenImpl(ctx, converted), + returnType, + // Functions give up ownership of their return value + /* origin */ 'runtime', + ); + }, + }; + + return impl; +} diff --git a/packages/typegpu/src/core/function/dualImpl.ts b/packages/typegpu/src/core/function/dualImpl.ts index da41f49262..c72200b326 100644 --- a/packages/typegpu/src/core/function/dualImpl.ts +++ b/packages/typegpu/src/core/function/dualImpl.ts @@ -2,6 +2,7 @@ import { type MapValueToSnippet, snip } from '../../data/snippet.ts'; import { setName } from '../../shared/meta.ts'; import { $gpuCallable } from '../../shared/symbols.ts'; import { tryConvertSnippet } from '../../tgsl/conversion.ts'; +import { concretize } from '../../tgsl/generationHelpers.ts'; import { type DualFn, isKnownAtComptime, @@ -21,10 +22,13 @@ interface DualImplOptions { args: MapValueToSnippet>, ) => string; readonly signature: - | { argTypes: BaseData[]; returnType: BaseData } + | { + argTypes: (BaseData | BaseData[])[]; + returnType: BaseData; + } | (( ...inArgTypes: MapValueToDataType> - ) => { argTypes: BaseData[]; returnType: BaseData }); + ) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData }); /** * Whether the function should skip trying to execute the "normal" implementation if * all arguments are known at compile time. @@ -112,7 +116,7 @@ export function dualImpl( return snip( options.codegenImpl(ctx, converted), - returnType, + concretize(returnType), // Functions give up ownership of their return value /* origin */ 'runtime', ); diff --git a/packages/typegpu/src/data/matrix.ts b/packages/typegpu/src/data/matrix.ts index 839d137a16..008d493c5b 100644 --- a/packages/typegpu/src/data/matrix.ts +++ b/packages/typegpu/src/data/matrix.ts @@ -1,4 +1,5 @@ import { comptime } from '../core/function/comptime.ts'; +import { callableSchema } from '../core/function/createCallableSchema.ts'; import { dualImpl } from '../core/function/dualImpl.ts'; import { stitch } from '../core/resolve/stitch.ts'; import { $repr } from '../shared/symbols.ts'; @@ -64,7 +65,7 @@ function createMatSchema< >( options: MatSchemaOptions, ): { type: TType; [$repr]: ValueType } & MatConstructor { - const construct = dualImpl({ + const construct = callableSchema({ name: options.type, normalImpl: (...args: (number | ColumnType)[]): ValueType => { const elements: number[] = []; @@ -94,7 +95,6 @@ function createMatSchema< return new options.MatImpl(...elements) as ValueType; }, - ignoreImplicitCastWarning: true, signature: (...args) => ({ argTypes: args.map((arg) => (isVec(arg) ? arg : f32)), returnType: schema as unknown as BaseData, diff --git a/packages/typegpu/src/data/numeric.ts b/packages/typegpu/src/data/numeric.ts index 590729efb2..dd74d292c6 100644 --- a/packages/typegpu/src/data/numeric.ts +++ b/packages/typegpu/src/data/numeric.ts @@ -1,5 +1,4 @@ import { stitch } from '../core/resolve/stitch.ts'; -import { dualImpl } from '../core/function/dualImpl.ts'; import { $internal } from '../shared/symbols.ts'; import type { AbstractFloat, @@ -11,6 +10,7 @@ import type { U16, U32, } from './wgslTypes.ts'; +import { callableSchema } from '../core/function/createCallableSchema.ts'; export const abstractInt = { [$internal]: {}, @@ -28,7 +28,7 @@ export const abstractFloat = { }, } as AbstractFloat; -const boolCast = dualImpl({ +const boolCast = callableSchema({ name: 'bool', signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: bool }), normalImpl(v?: number | boolean) { @@ -66,7 +66,7 @@ export const bool: Bool = Object.assign(boolCast, { type: 'bool', }) as unknown as Bool; -const u32Cast = dualImpl({ +const u32Cast = callableSchema({ name: 'u32', signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: u32 }), normalImpl(v?: number | boolean) { @@ -76,6 +76,17 @@ const u32Cast = dualImpl({ if (typeof v === 'boolean') { return v ? 1 : 0; } + if (!Number.isInteger(v)) { + const truncated = Math.trunc(v); + if (truncated < 0) { + return 0; + } + if (truncated > 0xffffffff) { + return 0xffffffff; + } + return truncated; + } + // Integer input: treat as bit reinterpretation (i32 -> u32) return (v & 0xffffffff) >>> 0; }, codegenImpl: (_ctx, [arg]) => @@ -106,7 +117,7 @@ export const u32: U32 = Object.assign(u32Cast, { type: 'u32', }) as unknown as U32; -const i32Cast = dualImpl({ +const i32Cast = callableSchema({ name: 'i32', signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: i32 }), normalImpl(v?: number | boolean) { @@ -149,9 +160,9 @@ export const i32: I32 = Object.assign(i32Cast, { type: 'i32', }) as unknown as I32; -const f32Cast = dualImpl({ +const f32Cast = callableSchema({ name: 'f32', - signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f32 }), + signature: (arg) => ({ argTypes: [arg ? [arg] : []], returnType: f32 }), normalImpl(v?: number | boolean) { if (v === undefined) { return 0; @@ -275,7 +286,7 @@ function roundToF16(x: number): number { return fromHalfBits(toHalfBits(x)); } -const f16Cast = dualImpl({ +const f16Cast = callableSchema({ name: 'f16', signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f16 }), normalImpl(v?: number | boolean) { diff --git a/packages/typegpu/src/data/vector.ts b/packages/typegpu/src/data/vector.ts index 2c20430b10..3ff86f10ea 100644 --- a/packages/typegpu/src/data/vector.ts +++ b/packages/typegpu/src/data/vector.ts @@ -1,4 +1,4 @@ -import { dualImpl } from '../core/function/dualImpl.ts'; +import { callableSchema } from '../core/function/createCallableSchema.ts'; import { stitch } from '../core/resolve/stitch.ts'; import { $internal, $repr } from '../shared/symbols.ts'; import { bool, f16, f32, i32, u32 } from './numeric.ts'; @@ -307,14 +307,13 @@ function makeVecSchema( ); }; - const construct = dualImpl({ + const construct = callableSchema({ name: type, signature: (...args) => ({ argTypes: args.map((arg) => isVec(arg) ? arg : primitive), returnType: schema as unknown as BaseData, }), normalImpl: cpuConstruct, - ignoreImplicitCastWarning: true, codegenImpl: (_ctx, args) => { if ( args.length === 1 && args[0]?.dataType === schema as unknown as BaseData diff --git a/packages/typegpu/src/errors.ts b/packages/typegpu/src/errors.ts index 1afd316749..9b878ec86f 100644 --- a/packages/typegpu/src/errors.ts +++ b/packages/typegpu/src/errors.ts @@ -210,3 +210,20 @@ export class WgslTypeError extends Error { Object.setPrototypeOf(this, WgslTypeError.prototype); } } + +export class SignatureNotSupportedError extends Error { + constructor(actual: BaseData[], candidates: BaseData[]) { + super( + `Unsupported data types: ${ + actual.map((a) => a.type).join(', ') + }. Supported types are: ${ + candidates + .map((r) => r.type) + .join(', ') + }.`, + ); + + // Set the prototype explicitly. + Object.setPrototypeOf(this, SignatureNotSupportedError.prototype); + } +} diff --git a/packages/typegpu/src/std/index.ts b/packages/typegpu/src/std/index.ts index b5aa26877f..30d74cbfcf 100644 --- a/packages/typegpu/src/std/index.ts +++ b/packages/typegpu/src/std/index.ts @@ -150,6 +150,7 @@ export { export { textureDimensions, + textureGather, textureLoad, textureSample, textureSampleBaseClampToEdge, diff --git a/packages/typegpu/src/std/numeric.ts b/packages/typegpu/src/std/numeric.ts index 1bc103fe92..6c87344182 100644 --- a/packages/typegpu/src/std/numeric.ts +++ b/packages/typegpu/src/std/numeric.ts @@ -1,5 +1,7 @@ import { dualImpl, MissingCpuImplError } from '../core/function/dualImpl.ts'; import { stitch } from '../core/resolve/stitch.ts'; +import type { AnyData } from '../data/dataTypes.ts'; +import { mat2x2f, mat3x3f, mat4x4f } from '../data/matrix.ts'; import { smoothstepScalar } from '../data/numberOps.ts'; import { abstractFloat, @@ -15,12 +17,15 @@ import { vec2f, vec2h, vec2i, + vec2u, vec3f, vec3h, vec3i, + vec3u, vec4f, vec4h, vec4i, + vec4u, } from '../data/vector.ts'; import { VectorOps } from '../data/vectorOps.ts'; import { @@ -30,10 +35,8 @@ import { type AnyMatInstance, type AnyNumericVecInstance, type AnySignedVecInstance, - type AnyWgslData, type BaseData, isHalfPrecisionSchema, - isNumericSchema, isVecInstance, type v2f, type v2h, @@ -44,8 +47,10 @@ import { type v4f, type v4h, type v4i, + type Vec2f, type VecData, } from '../data/wgslTypes.ts'; +import { SignatureNotSupportedError } from '../errors.ts'; import type { Infer } from '../shared/repr.ts'; import { unify } from '../tgsl/conversion.ts'; import type { ResolutionCtx } from '../types.ts'; @@ -62,6 +67,18 @@ const unaryIdentitySignature = (arg: BaseData) => { }; }; +const unaryIdentityRestrictedSignature = + (restrict: BaseData[]) => (arg: BaseData) => { + const argRestricted = unify([arg], restrict); + if (!argRestricted) { + throw new SignatureNotSupportedError([arg], restrict); + } + return { + argTypes: argRestricted, + returnType: argRestricted[0] as BaseData, + }; + }; + const variadicUnifySignature = (...args: BaseData[]) => { const uargs = unify(args) ?? args; return ({ @@ -70,6 +87,20 @@ const variadicUnifySignature = (...args: BaseData[]) => { }); }; +const variadicUnifyRestrictedSignature = (restrict: BaseData[]) => +( + ...args: BaseData[] +) => { + const uargs = unify(args, restrict); + if (!uargs) { + throw new SignatureNotSupportedError(args, restrict); + } + return ({ + argTypes: uargs, + returnType: uargs[0] as BaseData, + }); +}; + function variadicReduce(fn: (a: T, b: T) => T) { return (fst: T, ...rest: T[]): T => { let acc = fst; @@ -93,6 +124,23 @@ function variadicStitch(wrapper: string) { }; } +const anyFloatPrimitive = [f32, f16, abstractFloat]; +const anyFloatVec = [ + vec2f, + vec3f, + vec4f, + vec2h, + vec3h, + vec4h, +]; +const anyFloat = [...anyFloatPrimitive, ...anyFloatVec]; +const anyConcreteIntegerPrimitive = [i32, u32]; +const anyConcreteIntegerVec = [vec2i, vec3i, vec4i, vec2u, vec3u, vec4u]; +const anyConcreteInteger = [ + ...anyConcreteIntegerPrimitive, + ...anyConcreteIntegerVec, +]; + // std function cpuAbs(value: number): number; @@ -122,7 +170,7 @@ function cpuAcos(value: T): T { export const acos = dualImpl({ name: 'acos', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAcos, codegenImpl: (_ctx, [value]) => stitch`acos(${value})`, }); @@ -138,7 +186,7 @@ function cpuAcosh(value: T): T { export const acosh = dualImpl({ name: 'acosh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAcosh, codegenImpl: (_ctx, [value]) => stitch`acosh(${value})`, }); @@ -154,7 +202,7 @@ function cpuAsin(value: T): T { export const asin = dualImpl({ name: 'asin', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAsin, codegenImpl: (_ctx, [value]) => stitch`asin(${value})`, }); @@ -170,7 +218,7 @@ function cpuAsinh(value: T): T { export const asinh = dualImpl({ name: 'asinh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAsinh, codegenImpl: (_ctx, [value]) => stitch`asinh(${value})`, }); @@ -186,7 +234,7 @@ function cpuAtan(value: T): T { export const atan = dualImpl({ name: 'atan', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAtan, codegenImpl: (_ctx, [value]) => stitch`atan(${value})`, }); @@ -202,7 +250,7 @@ function cpuAtanh(value: T): T { export const atanh = dualImpl({ name: 'atanh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuAtanh, codegenImpl: (_ctx, [value]) => stitch`atanh(${value})`, }); @@ -221,13 +269,7 @@ function cpuAtan2(y: T, x: T): T { export const atan2 = dualImpl({ name: 'atan2', - signature: (...args) => { - const uargs = unify(args, [f32, f16, abstractFloat]) ?? args; - return ({ - argTypes: uargs, - returnType: uargs[0], - }); - }, + signature: variadicUnifyRestrictedSignature(anyFloat), normalImpl: cpuAtan2, codegenImpl: (_ctx, [y, x]) => stitch`atan2(${y}, ${x})`, }); @@ -243,7 +285,7 @@ function cpuCeil(value: T): T { export const ceil = dualImpl({ name: 'ceil', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuCeil, codegenImpl: (_ctx, [value]) => stitch`ceil(${value})`, }); @@ -280,7 +322,7 @@ function cpuCos(value: T): T { export const cos = dualImpl({ name: 'cos', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuCos, codegenImpl: (_ctx, [value]) => stitch`cos(${value})`, }); @@ -296,7 +338,7 @@ function cpuCosh(value: T): T { export const cosh = dualImpl({ name: 'cosh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuCosh, codegenImpl: (_ctx, [value]) => stitch`cosh(${value})`, }); @@ -311,7 +353,7 @@ function cpuCountLeadingZeros( export const countLeadingZeros = dualImpl({ name: 'countLeadingZeros', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countLeadingZeros not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countLeadingZeros(${value})`, @@ -327,7 +369,7 @@ function cpuCountOneBits( export const countOneBits = dualImpl({ name: 'countOneBits', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countOneBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countOneBits(${value})`, @@ -343,7 +385,7 @@ function cpuCountTrailingZeros( export const countTrailingZeros = dualImpl({ name: 'countTrailingZeros', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countTrailingZeros not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countTrailingZeros(${value})`, @@ -351,9 +393,7 @@ export const countTrailingZeros = dualImpl({ export const cross = dualImpl({ name: 'cross', - signature: (...args) => { - return ({ argTypes: args, returnType: args[0] }); - }, + signature: variadicUnifyRestrictedSignature([vec3f, vec3h]), normalImpl: (a: T, b: T): T => VectorOps.cross[a.kind](a, b), codegenImpl: (_ctx, [a, b]) => stitch`cross(${a}, ${b})`, @@ -372,15 +412,22 @@ function cpuDegrees(value: T): T { export const degrees = dualImpl({ name: 'degrees', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuDegrees, codegenImpl: (_ctx, [value]) => stitch`degrees(${value})`, }); export const determinant = dualImpl<(value: AnyMatInstance) => number>({ name: 'determinant', - // TODO: The return type is potentially wrong here, it should return whatever the matrix element type is. - signature: unaryIdentitySignature, + signature: (arg) => { + if ( + !(arg.type === 'mat2x2f' || arg.type === 'mat3x3f' || + arg.type === 'mat4x4f') + ) { + throw new SignatureNotSupportedError([arg], [mat2x2f, mat3x3f, mat4x4f]); + } + return { argTypes: [arg], returnType: f32 }; + }, normalImpl: 'CPU implementation for determinant not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`determinant(${value})`, @@ -403,10 +450,14 @@ function cpuDistance( export const distance = dualImpl({ name: 'distance', signature: (...args) => { - return ({ - argTypes: args, - returnType: isHalfPrecisionSchema(args[0]) ? f16 : f32, - }); + const uargs = unify(args, anyFloat); + if (!uargs) { + throw new SignatureNotSupportedError(args, anyFloat); + } + return { + argTypes: uargs, + returnType: isHalfPrecisionSchema(uargs[0]) ? f16 : f32, + }; }, normalImpl: cpuDistance, codegenImpl: (_ctx, [a, b]) => stitch`distance(${a}, ${b})`, @@ -450,7 +501,7 @@ function cpuExp(value: T): T { export const exp = dualImpl({ name: 'exp', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuExp, codegenImpl: (_ctx, [value]) => stitch`exp(${value})`, }); @@ -466,7 +517,7 @@ function cpuExp2(value: T): T { export const exp2 = dualImpl({ name: 'exp2', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuExp2, codegenImpl: (_ctx, [value]) => stitch`exp2(${value})`, }); @@ -487,10 +538,16 @@ function cpuExtractBits( export const extractBits = dualImpl({ name: 'extractBits', - signature: (arg, _offset, _count) => ({ - argTypes: [arg, u32, u32], - returnType: arg, - }), + signature: (arg, _offset, _count) => { + const argRestricted = unify([arg], anyConcreteInteger)?.[0]; + if (!argRestricted) { + throw new SignatureNotSupportedError([arg], anyConcreteInteger); + } + return { + argTypes: [argRestricted, u32, u32], + returnType: argRestricted, + }; + }, normalImpl: 'CPU implementation for extractBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [e, offset, count]) => @@ -501,12 +558,7 @@ export const faceForward = dualImpl< (e1: T, e2: T, e3: T) => T >({ name: 'faceForward', - signature: (...args) => { - return ({ - argTypes: args, - returnType: args[0], - }); - }, + signature: variadicUnifyRestrictedSignature(anyFloatVec), normalImpl: 'CPU implementation for faceForward not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [e1, e2, e3]) => stitch`faceForward(${e1}, ${e2}, ${e3})`, @@ -538,7 +590,7 @@ function cpuFirstTrailingBit( export const firstTrailingBit = dualImpl({ name: 'firstTrailingBit', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for firstTrailingBit not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`firstTrailingBit(${value})`, @@ -555,7 +607,7 @@ function cpuFloor(value: T): T { export const floor = dualImpl({ name: 'floor', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuFloor, codegenImpl: (_ctx, [arg]) => stitch`floor(${arg})`, }); @@ -577,10 +629,7 @@ function cpuFma( export const fma = dualImpl({ name: 'fma', - signature: (...args) => ({ - argTypes: args, - returnType: args[0], - }), + signature: variadicUnifyRestrictedSignature(anyFloat), normalImpl: cpuFma, codegenImpl: (_ctx, [e1, e2, e3]) => stitch`fma(${e1}, ${e2}, ${e3})`, }); @@ -596,7 +645,7 @@ function cpuFract(value: T): T { export const fract = dualImpl({ name: 'fract', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuFract, codegenImpl: (_ctx, [a]) => stitch`fract(${a})`, }); @@ -628,9 +677,7 @@ export const frexp = dualImpl({ const returnType = FrexpResults[value.type as keyof typeof FrexpResults]; if (!returnType) { - throw new Error( - `Unsupported data type for frexp: ${value.type}. Supported types are f32, f16, abstractFloat, vec2f, vec3f, vec4f, vec2h, vec3h, vec4h.`, - ); + throw new SignatureNotSupportedError([value], anyFloat); } return { argTypes: [value], returnType }; @@ -661,10 +708,16 @@ function cpuInsertBits( export const insertBits = dualImpl({ name: 'insertBits', - signature: (e, newbits, _offset, _count) => ({ - argTypes: [e, newbits, u32, u32], - returnType: e, - }), + signature: (e, newbits, _offset, _count) => { + const uargs = unify([e, newbits], anyConcreteInteger); + if (!uargs) { + throw new SignatureNotSupportedError([e, newbits], anyConcreteInteger); + } + return { + argTypes: [...uargs, u32, u32], + returnType: uargs[0] as AnyData, + }; + }, normalImpl: 'CPU implementation for insertBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [e, newbits, offset, count]) => @@ -684,7 +737,7 @@ function cpuInverseSqrt(value: T): T { export const inverseSqrt = dualImpl({ name: 'inverseSqrt', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuInverseSqrt, codegenImpl: (_ctx, [value]) => stitch`inverseSqrt(${value})`, }); @@ -740,10 +793,16 @@ function cpuLength(value: T): number { export const length = dualImpl({ name: 'length', - signature: (arg) => ({ - argTypes: [arg], - returnType: isHalfPrecisionSchema(arg) ? f16 : f32, - }), + signature: (arg) => { + const uarg = unify([arg], anyFloat); + if (!uarg) { + throw new SignatureNotSupportedError([arg], anyFloat); + } + return { + argTypes: uarg, + returnType: isHalfPrecisionSchema(uarg[0]) ? f16 : f32, + }; + }, normalImpl: cpuLength, codegenImpl: (_ctx, [arg]) => stitch`length(${arg})`, }); @@ -759,7 +818,7 @@ function cpuLog(value: T): T { export const log = dualImpl({ name: 'log', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuLog, codegenImpl: (_ctx, [value]) => stitch`log(${value})`, }); @@ -775,7 +834,7 @@ function cpuLog2(value: T): T { export const log2 = dualImpl({ name: 'log2', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuLog2, codegenImpl: (_ctx, [value]) => stitch`log2(${value})`, }); @@ -843,7 +902,22 @@ function cpuMix( export const mix = dualImpl({ name: 'mix', - signature: variadicUnifySignature, + signature: (...[e1, e2, e3]) => { + if (e1.type.startsWith('vec') && !e3.type.startsWith('vec')) { + const uarg = unify([e3], [(e1 as unknown as Vec2f).primitive]); + if (!uarg) { + throw new SignatureNotSupportedError([e3], [ + (e1 as unknown as Vec2f).primitive, + ]); + } + return { argTypes: [e1, e2, uarg[0] as AnyData], returnType: e1 }; + } + const uargs = unify([e1, e2, e3], anyFloat); + if (!uargs) { + throw new SignatureNotSupportedError([e1, e2, e3], anyFloat); + } + return { argTypes: uargs, returnType: uargs[0] as AnyData }; + }, normalImpl: cpuMix, codegenImpl: (_ctx, [e1, e2, e3]) => stitch`mix(${e1}, ${e2}, ${e3})`, }); @@ -896,7 +970,7 @@ export const modf: ModfOverload = dualImpl({ export const normalize = dualImpl({ name: 'normalize', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloatVec), normalImpl: (v: T): T => VectorOps.normalize[v.kind](v), codegenImpl: (_ctx, [value]) => stitch`normalize(${value})`, @@ -922,13 +996,7 @@ function powCpu( export const pow = dualImpl({ name: 'pow', - signature: (...args) => { - const uargs = unify(args, [f32, f16, abstractFloat]) ?? args; - return { - argTypes: uargs, - returnType: isNumericSchema(uargs[0]) ? uargs[1] : uargs[0], - }; - }, + signature: variadicUnifyRestrictedSignature(anyFloat), normalImpl: powCpu, codegenImpl: (_ctx, [lhs, rhs]) => stitch`pow(${lhs}, ${rhs})`, }); @@ -942,7 +1010,14 @@ function cpuQuantizeToF16( export const quantizeToF16 = dualImpl({ name: 'quantizeToF16', - signature: unaryIdentitySignature, + signature: (arg) => { + const candidates = [vec2f, vec3f, vec4f, f32]; + const uarg = unify([arg], candidates)?.[0]; + if (!uarg) { + throw new SignatureNotSupportedError([arg], candidates); + } + return { argTypes: [uarg], returnType: uarg }; + }, normalImpl: 'CPU implementation for quantizeToF16 not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`quantizeToF16(${value})`, @@ -961,17 +1036,23 @@ function cpuRadians(value: T): T { export const radians = dualImpl({ name: 'radians', - signature: (...args) => { - const uargs = unify(args, [f32, f16, abstractFloat]) ?? args; - return ({ argTypes: uargs, returnType: uargs[0] }); - }, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuRadians, codegenImpl: (_ctx, [value]) => stitch`radians(${value})`, }); export const reflect = dualImpl({ name: 'reflect', - signature: (...args) => ({ argTypes: args, returnType: args[0] }), + signature: (...args) => { + const uargs = unify(args, anyFloatVec); + if (!uargs) { + throw new SignatureNotSupportedError(args, anyFloatVec); + } + return { + argTypes: uargs, + returnType: uargs[0], + }; + }, normalImpl: (e1: T, e2: T): T => sub(e1, mul(2 * dot(e2, e1), e2)), codegenImpl: (_ctx, [e1, e2]) => stitch`reflect(${e1}, ${e2})`, @@ -986,8 +1067,8 @@ export const refract = dualImpl< codegenImpl: (_ctx, [e1, e2, e3]) => stitch`refract(${e1}, ${e2}, ${e3})`, signature: (e1, e2, _e3) => ({ argTypes: [ - e1 as AnyWgslData, - e2 as AnyWgslData, + e1, + e2, isHalfPrecisionSchema(e1) ? f16 : f32, ], returnType: e1, @@ -1001,7 +1082,7 @@ function cpuReverseBits(value: T): T { export const reverseBits = dualImpl({ name: 'reverseBits', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for reverseBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`reverseBits(${value})`, @@ -1020,7 +1101,7 @@ function cpuRound(value: T): T { export const round = dualImpl({ name: 'round', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuRound, codegenImpl: (_ctx, [value]) => stitch`round(${value})`, }); @@ -1038,7 +1119,7 @@ function cpuSaturate(value: T): T { export const saturate = dualImpl({ name: 'saturate', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuSaturate, codegenImpl: (_ctx, [value]) => stitch`saturate(${value})`, }); @@ -1054,7 +1135,14 @@ function cpuSign(e: T): T { export const sign = dualImpl({ name: 'sign', - signature: unaryIdentitySignature, + signature: (arg) => { + const candidates = [...anyFloat, i32, vec2i, vec3i, vec4i]; + const uarg = unify([arg], candidates)?.[0]; + if (!uarg) { + throw new SignatureNotSupportedError([arg], candidates); + } + return { argTypes: [uarg], returnType: uarg }; + }, normalImpl: cpuSign, codegenImpl: (_ctx, [e]) => stitch`sign(${e})`, }); @@ -1070,7 +1158,7 @@ function cpuSin(value: T): T { export const sin = dualImpl({ name: 'sin', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuSin, codegenImpl: (_ctx, [value]) => stitch`sin(${value})`, }); @@ -1088,7 +1176,7 @@ function cpuSinh(value: T): T { export const sinh = dualImpl({ name: 'sinh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuSinh, codegenImpl: (_ctx, [value]) => stitch`sinh(${value})`, }); @@ -1120,10 +1208,7 @@ function cpuSmoothstep( export const smoothstep = dualImpl({ name: 'smoothstep', - signature: (...args) => ({ - argTypes: args, - returnType: args[2], - }), + signature: variadicUnifyRestrictedSignature(anyFloat), normalImpl: cpuSmoothstep, codegenImpl: (_ctx, [edge0, edge1, x]) => stitch`smoothstep(${edge0}, ${edge1}, ${x})`, @@ -1140,7 +1225,7 @@ function cpuSqrt(value: T): T { export const sqrt = dualImpl({ name: 'sqrt', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuSqrt, codegenImpl: (_ctx, [value]) => stitch`sqrt(${value})`, }); @@ -1158,10 +1243,7 @@ function cpuStep(edge: T, x: T): T { export const step = dualImpl({ name: 'step', - signature: (...args) => { - const uargs = unify(args, [f32, f16, abstractFloat]) ?? args; - return { argTypes: uargs, returnType: uargs[0] }; - }, + signature: variadicUnifyRestrictedSignature(anyFloat), normalImpl: cpuStep, codegenImpl: (_ctx, [edge, x]) => stitch`step(${edge}, ${x})`, }); @@ -1179,7 +1261,7 @@ function cpuTan(value: T): T { export const tan = dualImpl({ name: 'tan', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuTan, codegenImpl: (_ctx, [value]) => stitch`tan(${value})`, }); @@ -1195,7 +1277,7 @@ function cpuTanh(value: T): T { export const tanh = dualImpl({ name: 'tanh', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: cpuTanh, codegenImpl: (_ctx, [value]) => stitch`tanh(${value})`, }); @@ -1216,7 +1298,7 @@ function cpuTrunc(value: T): T { export const trunc = dualImpl({ name: 'trunc', - signature: unaryIdentitySignature, + signature: unaryIdentityRestrictedSignature(anyFloat), normalImpl: 'CPU implementation for trunc not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`trunc(${value})`, diff --git a/packages/typegpu/src/std/operators.ts b/packages/typegpu/src/std/operators.ts index 1c0081e031..50418023fc 100644 --- a/packages/typegpu/src/std/operators.ts +++ b/packages/typegpu/src/std/operators.ts @@ -6,18 +6,82 @@ import { VectorOps } from '../data/vectorOps.ts'; import { type AnyMatInstance, type AnyNumericVecInstance, + type BaseData, isFloat32VecInstance, + isMat, isMatInstance, - isNumericSchema, + isVec, isVecInstance, type mBaseForVec, type vBaseForMat, } from '../data/wgslTypes.ts'; +import { SignatureNotSupportedError } from '../errors.ts'; import { unify } from '../tgsl/conversion.ts'; type NumVec = AnyNumericVecInstance; type Mat = AnyMatInstance; +const getPrimitive = (t: BaseData): BaseData => + 'primitive' in t ? (t.primitive as BaseData) : t; + +const makeBinarySignature = (opts?: { + matVecProduct?: boolean; + noMat?: boolean; + restrict?: BaseData[]; +}) => +(lhs: BaseData, rhs: BaseData) => { + const { restrict } = opts ?? {}; + const fail = (msg: string): never => { + if (restrict) { + throw new SignatureNotSupportedError([lhs, rhs], restrict); + } + throw new Error( + `Cannot apply operator to ${lhs.type} and ${rhs.type}: ${msg}`, + ); + }; + + if (opts?.noMat && (isMat(lhs) || isMat(rhs))) { + return fail('matrices not supported'); + } + const lhsC = isVec(lhs) || isMat(lhs); + const rhsC = isVec(rhs) || isMat(rhs); + + if (!lhsC && !rhsC) { + // scalar × scalar + const unified = unify([lhs, rhs], restrict); + if (!unified) return fail('incompatible scalar types'); + return { argTypes: unified, returnType: unified[0] }; + } + + if (lhsC && rhsC) { + // vec × mat or mat × vec + if (opts?.matVecProduct && isVec(lhs) !== isVec(rhs)) { + return { argTypes: [lhs, rhs], returnType: isVec(lhs) ? lhs : rhs }; + } + // composite × composite (same kind) + if (lhs.type !== rhs.type) return fail('operands must have the same type'); + return { argTypes: [lhs, rhs], returnType: lhs }; + } + + // scalar × composite + const [scalar, composite] = lhsC ? [rhs, lhs] : [lhs, rhs]; + const unified = unify([scalar], [getPrimitive(composite)]); + if (!unified) { + return fail(`scalar not convertible to ${getPrimitive(composite).type}`); + } + return { + argTypes: lhsC ? [lhs, unified[0]] : [unified[0], rhs], + returnType: composite, + }; +}; + +const binaryArithmeticSignature = makeBinarySignature(); +const binaryMulSignature = makeBinarySignature({ matVecProduct: true }); +const binaryDivSignature = makeBinarySignature({ + noMat: true, + restrict: [f32, f16, abstractFloat], +}); + function cpuAdd(lhs: number, rhs: number): number; // default addition function cpuAdd(lhs: number, rhs: T): T; // mixed addition function cpuAdd(lhs: T, rhs: number): T; // mixed addition @@ -52,13 +116,7 @@ function cpuAdd(lhs: number | NumVec | Mat, rhs: number | NumVec | Mat) { export const add = dualImpl({ name: 'add', - signature: (...args) => { - const uargs = unify(args) ?? args; - return { - argTypes: uargs, - returnType: isNumericSchema(uargs[0]) ? uargs[1] : uargs[0], - }; - }, + signature: binaryArithmeticSignature, normalImpl: cpuAdd, codegenImpl: (_ctx, [lhs, rhs]) => stitch`(${lhs} + ${rhs})`, }); @@ -82,13 +140,7 @@ function cpuSub(lhs: number | NumVec | Mat, rhs: number | NumVec | Mat) { export const sub = dualImpl({ name: 'sub', - signature: (...args) => { - const uargs = unify(args) ?? args; - return { - argTypes: uargs, - returnType: isNumericSchema(uargs[0]) ? uargs[1] : uargs[0], - }; - }, + signature: binaryArithmeticSignature, normalImpl: cpuSub, codegenImpl: (_ctx, [lhs, rhs]) => stitch`(${lhs} - ${rhs})`, }); @@ -136,25 +188,7 @@ function cpuMul(lhs: number | NumVec | Mat, rhs: number | NumVec | Mat) { export const mul = dualImpl({ name: 'mul', - signature: (...args) => { - const uargs = unify(args) ?? args; - const returnType = isNumericSchema(uargs[0]) - // Scalar * Scalar/Vector/Matrix - ? uargs[1] - : isNumericSchema(uargs[1]) - // Vector/Matrix * Scalar - ? uargs[0] - : uargs[0].type.startsWith('vec') - // Vector * Vector/Matrix - ? uargs[0] - : uargs[1].type.startsWith('vec') - // Matrix * Vector - ? uargs[1] - // Matrix * Matrix - : uargs[0]; - - return ({ argTypes: uargs, returnType }); - }, + signature: binaryMulSignature, normalImpl: cpuMul, codegenImpl: (_ctx, [lhs, rhs]) => stitch`(${lhs} * ${rhs})`, }); @@ -183,13 +217,7 @@ function cpuDiv(lhs: NumVec | number, rhs: NumVec | number): NumVec | number { export const div = dualImpl({ name: 'div', - signature: (...args) => { - const uargs = unify(args, [f32, f16, abstractFloat]) ?? args; - return ({ - argTypes: uargs, - returnType: isNumericSchema(uargs[0]) ? uargs[1] : uargs[0], - }); - }, + signature: binaryDivSignature, normalImpl: cpuDiv, codegenImpl: (_ctx, [lhs, rhs]) => stitch`(${lhs} / ${rhs})`, ignoreImplicitCastWarning: true, @@ -208,13 +236,7 @@ type ModOverload = { */ export const mod: ModOverload = dualImpl({ name: 'mod', - signature: (...args) => { - const uargs = unify(args) ?? args; - return { - argTypes: uargs, - returnType: isNumericSchema(uargs[0]) ? uargs[1] : uargs[0], - }; - }, + signature: binaryArithmeticSignature, normalImpl(a: T, b: T): T { if (typeof a === 'number' && typeof b === 'number') { return (a % b) as T; // scalar % scalar diff --git a/packages/typegpu/src/std/texture.ts b/packages/typegpu/src/std/texture.ts index f1b1d1c055..069d18327a 100644 --- a/packages/typegpu/src/std/texture.ts +++ b/packages/typegpu/src/std/texture.ts @@ -7,7 +7,7 @@ import { } from '../data/texture.ts'; import type { TexelData } from '../core/texture/texture.ts'; import { dualImpl, MissingCpuImplError } from '../core/function/dualImpl.ts'; -import { f32, u32 } from '../data/numeric.ts'; +import { f32, i32, u32 } from '../data/numeric.ts'; import { vec2u, vec3u, vec4f, vec4i, vec4u } from '../data/vector.ts'; import { type BaseData, @@ -504,6 +504,136 @@ export const textureDimensions = dualImpl({ }, }); +type Gather2dArgs = [ + component: number, + texture: T, + sampler: sampler, + coords: v2f, + offset?: v2i, +]; +type Gather2dArrayArgs = [ + component: number, + texture: T, + sampler: sampler, + coords: v2f, + arrayIndex: number, + offset?: v2i, +]; +type GatherCubeArgs = [ + component: number, + texture: T, + sampler: sampler, + coords: v3f, +]; +type GatherCubeArrayArgs = [ + component: number, + texture: T, + sampler: sampler, + coords: v3f, + arrayIndex: number, +]; +type GatherDepth2dArgs = [ + texture: textureDepth2d, + sampler: sampler, + coords: v2f, + offset?: v2i, +]; +type GatherDepth2dArrayArgs = [ + texture: textureDepth2dArray, + sampler: sampler, + coords: v2f, + arrayIndex: number, + offset?: v2i, +]; +type GatherDepthCubeArgs = [ + texture: textureDepthCube, + sampler: sampler, + coords: v3f, +]; +type GatherDepthCubeArrayArgs = [ + texture: textureDepthCubeArray, + sampler: sampler, + coords: v3f, + arrayIndex: number, +]; + +type TextureGatherCpuArgs = + | Gather2dArgs + | Gather2dArrayArgs + | GatherCubeArgs + | GatherCubeArrayArgs + | GatherDepth2dArgs + | GatherDepth2dArrayArgs + | GatherDepthCubeArgs + | GatherDepthCubeArrayArgs; + +type TextureGatherCpuFn = { + ( + ...args: Gather2dArgs + ): PrimitiveToLoadedType[T[typeof $internal]['type']]; + ( + ...args: Gather2dArrayArgs + ): PrimitiveToLoadedType[T[typeof $internal]['type']]; + ( + ...args: GatherCubeArgs + ): PrimitiveToLoadedType[T[typeof $internal]['type']]; + ( + ...args: GatherCubeArrayArgs + ): PrimitiveToLoadedType[T[typeof $internal]['type']]; + (...args: GatherDepth2dArgs): v4f; + (...args: GatherDepth2dArrayArgs): v4f; + (...args: GatherDepthCubeArgs): v4f; + (...args: GatherDepthCubeArrayArgs): v4f; +}; + +export const textureGatherCpu: TextureGatherCpuFn = ( + ...args: TextureGatherCpuArgs +): v4f => { + throw new Error( + 'Texture gather relies on GPU resources and cannot be executed outside of a draw call', + ); +}; + +const sampleTypeToVecType = { + f32: vec4f, + i32: vec4i, + u32: vec4u, +}; + +export const textureGather = dualImpl({ + name: 'textureGather', + normalImpl: textureGatherCpu, + codegenImpl: (_ctx, args) => stitch`textureGather(${args})`, + signature: (...args) => { + if (args[0].type.startsWith('texture')) { + const [texture, sampler, coords, _, ...rest] = args; + + const isArrayTexture = texture.type === 'texture_depth_2d_array' || + texture.type === 'texture_depth_cube_array'; + + const argTypes = isArrayTexture + ? [texture, sampler, coords, [u32, i32], ...rest] + : args as BaseData[]; + + return { argTypes: argTypes as BaseData[], returnType: vec4f }; + } + + const [_, texture, sampler, coords, ...rest] = args; + + const isArrayTexture = texture.type === 'texture_2d_array' || + texture.type === 'texture_cube_array'; + + const argTypes = isArrayTexture + ? [[u32, i32], texture, sampler, coords, [u32, i32], ...rest] + : [[u32, i32], texture, sampler, coords, ...rest]; + + return { + argTypes: argTypes as BaseData[], + returnType: sampleTypeToVecType[(texture as WgslTexture).sampleType.type], + }; + }, +}); + function textureSampleCompareCpu( texture: T, sampler: comparisonSampler, diff --git a/packages/typegpu/src/tgsl/conversion.ts b/packages/typegpu/src/tgsl/conversion.ts index 0dbae8af35..0b5e2f9664 100644 --- a/packages/typegpu/src/tgsl/conversion.ts +++ b/packages/typegpu/src/tgsl/conversion.ts @@ -5,6 +5,7 @@ import { derefSnippet, RefOperator } from '../data/ref.ts'; import { schemaCallWrapperGPU } from '../data/schemaCallWrapper.ts'; import { snip, type Snippet } from '../data/snippet.ts'; import { + type AnyWgslData, type BaseData, type F16, type F32, @@ -57,11 +58,19 @@ function getAutoConversionRank( if (trueDst.type === 'f16') return { rank: 7, action: 'none' }; } - if (isVec(trueSrc) && isVec(trueDst)) { + if ( + isVec(trueSrc) && isVec(trueDst) && + // Same length vectors + trueSrc.type[3] === trueDst.type[3] + ) { return getAutoConversionRank(trueSrc.primitive, trueDst.primitive); } - if (isMat(trueSrc) && isMat(trueDst)) { + if ( + isMat(trueSrc) && isMat(trueDst) && + // Same dimensions + trueSrc.type[3] === trueDst.type[3] + ) { // Matrix conversion rank depends only on component type (always f32 for now) return { rank: 0, action: 'none' }; } @@ -272,11 +281,7 @@ export function unify( return undefined; } - const primitiveTypes = inTypes.map((type) => - isVec(type) || isMat(type) ? type.primitive : type as BaseData - ); - - const conversion = getBestConversion(primitiveTypes, restrictTo); + const conversion = getBestConversion(inTypes as BaseData[], restrictTo); if (!conversion) { return undefined; } @@ -332,38 +337,40 @@ Consider using explicit conversions instead.`, export function tryConvertSnippet( ctx: ResolutionCtx, snippet: Snippet, - targetDataType: BaseData, + targetDataTypes: BaseData | BaseData[], verbose = true, ): Snippet { - if (targetDataType === snippet.dataType) { - return snip(snippet.value, targetDataType, snippet.origin); - } + const targets = Array.isArray(targetDataTypes) + ? targetDataTypes + : [targetDataTypes]; - if (snippet.dataType === UnknownData) { - // This is it, it's now or never. We expect a specific type, and we're going to get it - return snip( - stitch`${snip(snippet.value, targetDataType, snippet.origin)}`, - targetDataType, - snippet.origin, - ); - } + const { value, dataType, origin } = snippet; - const converted = convertToCommonType( - ctx, - [snippet], - [targetDataType], - verbose, - ); + if (targets.length === 1) { + const target = targets[0] as AnyWgslData; - if (!converted) { - throw new WgslTypeError( - `Cannot convert value of type '${ - String(snippet.dataType) - }' to type '${targetDataType.type}'`, - ); + if (target === dataType) { + return snip(value, target, origin); + } + + if (typeof dataType === 'symbol') { + // Commit unknown to the expected type. + return snip(stitch`${snip(value, target, origin)}`, target, origin); + } + } + + const converted = convertToCommonType(ctx, [snippet], targets, verbose); + if (converted) { + return converted[0] as Snippet; } - return converted[0] as Snippet; + throw new WgslTypeError( + `Cannot convert value of type '${ + String( + dataType, + ) + }' to any of the target types: [${targets.map((t) => t.type).join(', ')}]`, + ); } export function convertStructValues( diff --git a/packages/typegpu/src/tgsl/generationHelpers.ts b/packages/typegpu/src/tgsl/generationHelpers.ts index 228daeca41..3d8333a321 100644 --- a/packages/typegpu/src/tgsl/generationHelpers.ts +++ b/packages/typegpu/src/tgsl/generationHelpers.ts @@ -47,16 +47,18 @@ export function concretize(type: T): T | F32 | I32 { return type; } -export function concretizeSnippets(args: Snippet[]): Snippet[] { - return args.map((snippet) => - snip( - snippet.value, - concretize(snippet.dataType as AnyWgslData), - /* origin */ snippet.origin, - ) +export function concretizeSnippet(snippet: Snippet): Snippet { + return snip( + snippet.value, + concretize(snippet.dataType as AnyWgslData), + snippet.origin, ); } +export function concretizeSnippets(args: Snippet[]): Snippet[] { + return args.map(concretizeSnippet); +} + export type GenerationCtx = ResolutionCtx & { readonly pre: string; /** @@ -66,7 +68,7 @@ export type GenerationCtx = ResolutionCtx & { * It is used exclusively for inferring the types of structs and arrays. * It is modified exclusively by `typedExpression` function. */ - expectedType: BaseData | undefined; + expectedType: (BaseData | BaseData[]) | undefined; readonly topFunctionScope: FunctionScopeLayer | undefined; readonly topFunctionReturnType: BaseData | undefined; diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 5a3691a550..dde1be7a79 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -136,7 +136,7 @@ function operatorToType< TR extends wgsl.BaseData | UnknownData, >(lhs: TL, op: Operator, rhs?: TR): TL | TR | wgsl.Bool { if (!rhs) { - if (op === '!' || op === '~') { + if (op === '!') { return bool; } @@ -277,7 +277,7 @@ ${this.ctx.pre}}`; */ public typedExpression( expression: tinyest.Expression, - expectedType: wgsl.BaseData, + expectedType: wgsl.BaseData | wgsl.BaseData[], ) { const prevExpectedType = this.ctx.expectedType; this.ctx.expectedType = expectedType; diff --git a/packages/typegpu/src/types.ts b/packages/typegpu/src/types.ts index ea0a13e700..6174abde71 100644 --- a/packages/typegpu/src/types.ts +++ b/packages/typegpu/src/types.ts @@ -361,7 +361,7 @@ export function getOwnSnippet(value: unknown): Snippet | undefined { export interface GPUCallable { [$gpuCallable]: { strictSignature?: - | { argTypes: BaseData[]; returnType: BaseData } + | { argTypes: (BaseData | BaseData[])[]; returnType: BaseData } | undefined; call(ctx: ResolutionCtx, args: MapValueToSnippet): Snippet; }; diff --git a/packages/typegpu/tests/examples/individual/disco.test.ts b/packages/typegpu/tests/examples/individual/disco.test.ts index ab46090251..6a50324011 100644 --- a/packages/typegpu/tests/examples/individual/disco.test.ts +++ b/packages/typegpu/tests/examples/individual/disco.test.ts @@ -76,7 +76,7 @@ describe('disco example', () => { var paletteColor = palette((length(originalUv) + (time * 0.9f))); radialLength = (sin(((radialLength * 8f) + time)) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.1, radialLength); + radialLength = smoothstep(0f, 0.1f, radialLength); radialLength = (0.1f / radialLength); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); } @@ -105,7 +105,7 @@ describe('disco example', () => { var paletteColor = palette(((length(originalUv) + (time * 0.8f)) + (iterationF32 * 0.05f))); radialLength = (sin(((radialLength * 7f) + (time * 0.9f))) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.11, radialLength); + radialLength = smoothstep(0f, 0.11f, radialLength); radialLength = (0.055f / (radialLength + 1e-5f)); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); } @@ -135,7 +135,7 @@ describe('disco example', () => { var radialLength = (length(aspectUv) * exp((-(length(originalUv)) * (1.3f + (iterationF32 * 0.06f))))); radialLength = (sin(((radialLength * (7.2f + (iterationF32 * 0.8f))) + (time_1 * (1.1f + (iterationF32 * 0.2f))))) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.105, radialLength); + radialLength = smoothstep(0f, 0.105f, radialLength); radialLength = ((0.058f + (iterationF32 * 6e-3f)) / (radialLength + 1e-5f)); var paletteColor = palette(((length(originalUv) + (time_1 * 0.65f)) + (iterationF32 * 0.045f))); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); @@ -165,7 +165,7 @@ describe('disco example', () => { var paletteColor = palette(((length(originalUv) + (time * 0.9f)) + (iterationF32 * 0.08f))); radialLength = (sin(((radialLength * (6f + iterationF32)) + time)) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.1, radialLength); + radialLength = smoothstep(0f, 0.1f, radialLength); radialLength = ((0.085f + (iterationF32 * 5e-3f)) / (radialLength + 1e-5f)); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); } @@ -193,7 +193,7 @@ describe('disco example', () => { var radialLength = (length(warpedUv) * exp((-(length(originalUv)) * (1.4f + (iterationF32 * 0.05f))))); radialLength = (sin(((radialLength * (7f + (iterationF32 * 0.7f))) + (time_1 * (0.9f + (iterationF32 * 0.15f))))) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.1, radialLength); + radialLength = smoothstep(0f, 0.1f, radialLength); radialLength = ((0.05f + (iterationF32 * 5e-3f)) / (radialLength + 1e-5f)); var paletteColor = palette(((length(originalUv) + (time_1 * 0.7f)) + (iterationF32 * 0.04f))); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); @@ -225,7 +225,7 @@ describe('disco example', () => { var radialLength = (length(aspectUv) * exp((-(length(originalUv)) * (1.2f + (iterationF32 * 0.08f))))); radialLength = (sin(((radialLength * (7.5f + iterationF32)) + (time_1 * (1f + (iterationF32 * 0.1f))))) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.11, radialLength); + radialLength = smoothstep(0f, 0.11f, radialLength); radialLength = ((0.06f + (iterationF32 * 5e-3f)) / (radialLength + 1e-5f)); var paletteColor = palette(((length(originalUv) + (time_1 * 0.75f)) + (iterationF32 * 0.05f))); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); @@ -290,7 +290,7 @@ describe('disco example', () => { var radialLength = (length(aspectUv) * exp((-(length(originalUv)) * 2f))); radialLength = (sin(((radialLength * 8f) + time)) / 8f); radialLength = abs(radialLength); - radialLength = smoothstep(0, 0.1, radialLength); + radialLength = smoothstep(0f, 0.1f, radialLength); radialLength = (0.06f / radialLength); var paletteColor = palette((length(originalUv) + (time * 0.9f))); accumulatedColor = accumulate(accumulatedColor, paletteColor, radialLength); diff --git a/packages/typegpu/tests/examples/individual/image-tuning.test.ts b/packages/typegpu/tests/examples/individual/image-tuning.test.ts index 5e74ce5fdb..339a67be0c 100644 --- a/packages/typegpu/tests/examples/individual/image-tuning.test.ts +++ b/packages/typegpu/tests/examples/individual/image-tuning.test.ts @@ -87,7 +87,7 @@ describe('image tuning example', () => { let highlightShift = (adjustments.highlights - 1f); let highlightBiased = select((highlightShift * 0.25f), highlightShift, (adjustments.highlights >= 1f)); let highlightFactor = (1f + ((highlightBiased * 0.5f) * contrastColorLuminance)); - let highlightWeight = smoothstep(0.5, 1, contrastColorLuminance); + let highlightWeight = smoothstep(0.5f, 1f, contrastColorLuminance); let highlightLuminanceAdjust = (contrastLuminance * highlightFactor); let highlightLuminance = mix(contrastLuminance, saturate(highlightLuminanceAdjust), highlightWeight); var highlightColor = mix(contrastColor, saturate((contrastColor * highlightFactor)), highlightWeight); diff --git a/packages/typegpu/tests/examples/individual/jelly-slider.test.ts b/packages/typegpu/tests/examples/individual/jelly-slider.test.ts index a0901b2cb6..685949c1db 100644 --- a/packages/typegpu/tests/examples/individual/jelly-slider.test.ts +++ b/packages/typegpu/tests/examples/individual/jelly-slider.test.ts @@ -341,7 +341,7 @@ describe('jelly-slider example', () => { var centeredUV = vec2f((uvX_orig - 0.5f), (uvZ_orig - 0.5f)); var finalUV = vec2f(centeredUV.x, (1f - (pow((abs((centeredUV.y - 0.5f)) * 2f), 2f) * 0.3f))); let density = max(0f, ((textureSampleLevel(bezierTexture, filteringSampler, finalUV, 0).x - 0.25f) * 8f)); - let fadeX = smoothstep(0, -0.2, (hitPosition.x - endCapX)); + let fadeX = smoothstep(0f, -0.2f, (hitPosition.x - endCapX)); let fadeZ = (1f - pow((abs((centeredUV.y - 0.5f)) * 2f), 3f)); let fadeStretch = saturate((1f - sliderStretch)); let edgeFade = ((saturate(fadeX) * saturate(fadeZ)) * fadeStretch); @@ -467,7 +467,7 @@ describe('jelly-slider example', () => { let zDirection = sign(position.z); var zAxisVector = vec3f(0f, 0f, zDirection); let edgeBlendDistance = ((edgeContrib * 0.024f) + (zContrib * 0.17f)); - let blendFactor = smoothstep(edgeBlendDistance, 0, ((zDistance * zContrib) + (edgeDistance * edgeContrib))); + let blendFactor = smoothstep(edgeBlendDistance, 0f, ((zDistance * zContrib) + (edgeDistance * edgeContrib))); var normal2D = vec3f((*gradient2D).xy, 0f); var blendedNormal = mix(zAxisVector, normal2D, ((blendFactor * 0.5f) + 0.5f)); var normal = normalize(blendedNormal); diff --git a/packages/typegpu/tests/examples/individual/jelly-switch.test.ts b/packages/typegpu/tests/examples/individual/jelly-switch.test.ts index 69dfb5be32..9775631c48 100644 --- a/packages/typegpu/tests/examples/individual/jelly-switch.test.ts +++ b/packages/typegpu/tests/examples/individual/jelly-switch.test.ts @@ -265,7 +265,7 @@ describe('jelly switch example', () => { let sqDist = sqLength((hitPosition - vec3f(switchX, 0f, 0f))); var bounceLight = ((*jellyColor).xyz * ((1f / ((sqDist * 15f) + 1f)) * 0.4f)); var sideBounceLight = (((*jellyColor).xyz * ((1f / ((sqDist * 40f) + 1f)) * 0.3f)) * abs(newNormal.z)); - let emission = ((smoothstep(0.7, 1, (*state).progress) * 2f) + 0.7f); + let emission = ((smoothstep(0.7f, 1f, (*state).progress) * 2f) + 0.7f); var litColor = calculateLighting(hitPosition, newNormal, rayOrigin); var backgroundColor = ((applyAO((select(vec3f(1), vec3f(0.20000000298023224), (darkModeUniform == 1u)) * litColor), hitPosition, newNormal) + vec4f((bounceLight * emission), 0f)) + vec4f((sideBounceLight * emission), 0f)); return vec4f(backgroundColor.xyz, 1f); diff --git a/packages/typegpu/tests/examples/individual/jump-flood-distance.test.ts b/packages/typegpu/tests/examples/individual/jump-flood-distance.test.ts index 6472bc2d55..41b9960c3a 100644 --- a/packages/typegpu/tests/examples/individual/jump-flood-distance.test.ts +++ b/packages/typegpu/tests/examples/individual/jump-flood-distance.test.ts @@ -205,7 +205,7 @@ describe('jump flood (distance) example', () => { baseColor = insideBase; } let contourFreq = (maxDist / 12f); - let contour = smoothstep(0, 0.15, abs((fract((unsigned / contourFreq)) - 0.5f))); + let contour = smoothstep(0f, 0.15f, abs((fract((unsigned / contourFreq)) - 0.5f))); var color = (baseColor * (0.7f + (0.3f * contour))); return vec4f(color, 1f); }" diff --git a/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts b/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts index 3106c96966..40a3b2632d 100644 --- a/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts +++ b/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts @@ -124,11 +124,11 @@ describe('tgsl parsing test example', () => { fn arrayAndStructConstructorsTest() -> bool { var s = true; var defaultComplexStruct = ComplexStruct(); - s = (s && (2 == 2)); + s = (s && (2 == 2i)); s = (s && (defaultComplexStruct.arr[0i] == 0i)); s = (s && (defaultComplexStruct.arr[1i] == 0i)); var defaultComplexArray = array(); - s = (s && (3 == 3)); + s = (s && (3 == 3i)); s = (s && all(defaultComplexArray[0i].vec == vec2f())); s = (s && all(defaultComplexArray[1i].vec == vec2f())); s = (s && all(defaultComplexArray[2i].vec == vec2f())); diff --git a/packages/typegpu/tests/indent.test.ts b/packages/typegpu/tests/indent.test.ts index dc9d190477..bc58844f4c 100644 --- a/packages/typegpu/tests/indent.test.ts +++ b/packages/typegpu/tests/indent.test.ts @@ -348,7 +348,7 @@ describe('indents', () => { }, })((input) => { const uniBoid = layout.$.boids; - for (let i = d.u32(); i < std.floor(std.sin(123)); i++) { + for (let i = d.u32(); i < std.floor(std.sin(Math.PI / 2)); i++) { const sampled = std.textureSample( layout.$.sampled, layout.$.sampler, @@ -404,7 +404,7 @@ describe('indents', () => { @vertex fn someVertex(input: someVertex_Input) -> someVertex_Output { let uniBoid = (&boids); - for (var i = 0u; (i < -1u); i++) { + for (var i = 0u; (i < 1u); i++) { var sampled_1 = textureSample(sampled, sampler_1, vec2f(0.5), i); var someVal = textureLoad(smoothRender, vec2i(), 0); if (((someVal.x + sampled_1.x) > 0.5f)) { diff --git a/packages/typegpu/tests/primitiveCast.test.ts b/packages/typegpu/tests/primitiveCast.test.ts index 5f37c4c694..a7be1168c4 100644 --- a/packages/typegpu/tests/primitiveCast.test.ts +++ b/packages/typegpu/tests/primitiveCast.test.ts @@ -6,7 +6,7 @@ describe('u32', () => { expect(u32(10)).toBe(10); expect(u32(10.5)).toBe(10); expect(u32(-10)).toBe(4294967286); - expect(u32(-10.5)).toBe(4294967286); + expect(u32(-10.5)).toBe(0); expect(u32(4294967295)).toBe(4294967295); expect(u32(4294967296)).toBe(0); expect(u32(4294967297)).toBe(1); diff --git a/packages/typegpu/tests/std/texture/textureGather.test.ts b/packages/typegpu/tests/std/texture/textureGather.test.ts new file mode 100644 index 0000000000..be53681322 --- /dev/null +++ b/packages/typegpu/tests/std/texture/textureGather.test.ts @@ -0,0 +1,109 @@ +/** biome-ignore-all lint/correctness/noConstantCondition: we are using it intentionally to prune type checks */ +import { describe, expect, expectTypeOf } from 'vitest'; +import { it } from '../../utils/extendedIt.ts'; +import { textureGather } from '../../../src/std/texture.ts'; +import tgpu from '../../../src/index.ts'; +import * as d from '../../../src/data/index.ts'; +import { bindGroupLayout } from '../../../src/tgpuBindGroupLayout.ts'; +import { resolve } from '../../../src/core/resolve/tgpuResolve.ts'; + +describe('textureGather', () => { + it('Has correct signatures', () => { + const testLayout = bindGroupLayout({ + tex2d: { texture: d.texture2d() }, + tex2d_u32: { texture: d.texture2d(d.u32) }, + tex2d_array: { texture: d.texture2dArray(d.i32) }, + texcube_array: { texture: d.textureCubeArray() }, + texdepth2d: { texture: d.textureDepth2d() }, + texdepth2d_array: { texture: d.textureDepth2dArray() }, + + sampler: { sampler: 'non-filtering' }, + }); + + expectTypeOf(() => {}); + + const testFn = tgpu.fn([])(() => { + const uv2d = d.vec2f(0.5, 0.5); + const uv3d = d.vec3f(0.5, 0.5, 0); + const idx = d.f32(1.2); // f32 to verify proper conversion (implicit in this case) + const component = d.i32(0); + + const gather2d = textureGather( + component, + testLayout.$.tex2d, + testLayout.$.sampler, + uv2d, + ); + const gather2d_u32 = textureGather( + component, + testLayout.$.tex2d_u32, + testLayout.$.sampler, + uv2d, + ); + const gather2d_array = textureGather( + component, + testLayout.$.tex2d_array, + testLayout.$.sampler, + uv2d, + idx, + ); + const gathercube_array = textureGather( + component, + testLayout.$.texcube_array, + testLayout.$.sampler, + uv3d, + idx, + ); + const gatherdepth2d = textureGather( + testLayout.$.texdepth2d, + testLayout.$.sampler, + uv2d, + ); + const gatherdepth2d_array = textureGather( + testLayout.$.texdepth2d_array, + testLayout.$.sampler, + uv2d, + idx, + ); + + if (false) { + expectTypeOf(gather2d).toEqualTypeOf(); + expectTypeOf(gather2d_u32).toEqualTypeOf(); + expectTypeOf(gather2d_array).toEqualTypeOf(); + expectTypeOf(gathercube_array).toEqualTypeOf(); + expectTypeOf(gatherdepth2d).toEqualTypeOf(); + expectTypeOf(gatherdepth2d_array).toEqualTypeOf(); + } + }); + + expect(resolve([testFn])).toMatchInlineSnapshot(` + "@group(0) @binding(0) var tex2d: texture_2d; + + @group(0) @binding(6) var sampler_1: sampler; + + @group(0) @binding(1) var tex2d_u32: texture_2d; + + @group(0) @binding(2) var tex2d_array: texture_2d_array; + + @group(0) @binding(3) var texcube_array: texture_cube_array; + + @group(0) @binding(4) var texdepth2d: texture_depth_2d; + + @group(0) @binding(5) var texdepth2d_array: texture_depth_2d_array; + + fn testFn() { + var uv2d = vec2f(0.5); + var uv3d = vec3f(0.5, 0.5, 0); + const idx = 1.2000000476837158f; + const component = 0i; + var gather2d = textureGather(component, tex2d, sampler_1, uv2d); + var gather2d_u32 = textureGather(component, tex2d_u32, sampler_1, uv2d); + var gather2d_array = textureGather(component, tex2d_array, sampler_1, uv2d, u32(idx)); + var gathercube_array = textureGather(component, texcube_array, sampler_1, uv3d, u32(idx)); + var gatherdepth2d = textureGather(texdepth2d, sampler_1, uv2d); + var gatherdepth2d_array = textureGather(texdepth2d_array, sampler_1, uv2d, u32(idx)); + + }" + `); + }); +}); diff --git a/packages/typegpu/tests/tgsl/typeInference.test.ts b/packages/typegpu/tests/tgsl/typeInference.test.ts index d426b04fee..e5ace66be9 100644 --- a/packages/typegpu/tests/tgsl/typeInference.test.ts +++ b/packages/typegpu/tests/tgsl/typeInference.test.ts @@ -209,7 +209,7 @@ describe('wgsl generator type inference', () => { expect(() => tgpu.resolve([add])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:add: Cannot convert value of type 'u32' to type 'void'] + - fn:add: Cannot convert value of type 'u32' to any of the target types: [void]] `); }); @@ -221,7 +221,7 @@ describe('wgsl generator type inference', () => { expect(() => tgpu.resolve([add])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:add: Cannot convert value of type 'abstractInt' to type 'vec3f'] + - fn:add: Cannot convert value of type 'abstractInt' to any of the target types: [vec3f]] `); }); @@ -269,7 +269,7 @@ describe('wgsl generator type inference', () => { expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:myFn: Cannot convert value of type 'vec2' to type 'bool'] + - fn:myFn: Cannot convert value of type 'vec2' to any of the target types: [bool]] `); }); @@ -284,7 +284,7 @@ describe('wgsl generator type inference', () => { expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:myFn: Cannot convert value of type 'mat2x2f' to type 'bool'] + - fn:myFn: Cannot convert value of type 'mat2x2f' to any of the target types: [bool]] `); }); @@ -301,7 +301,7 @@ describe('wgsl generator type inference', () => { expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:myFn: Cannot convert value of type 'abstractInt' to type 'bool'] + - fn:myFn: Cannot convert value of type 'abstractInt' to any of the target types: [bool]] `); }); diff --git a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts index 407cc525e7..d43f1fd6fe 100644 --- a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts +++ b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts @@ -869,7 +869,7 @@ describe('wgslGenerator', () => { expect(() => tgpu.resolve([testFn])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:testFn: Cannot convert value of type 'arrayOf(i32, 3)' to type 'vec2f'] + - fn:testFn: Cannot convert value of type 'arrayOf(i32, 3)' to any of the target types: [vec2f]] `); }); @@ -898,7 +898,7 @@ describe('wgslGenerator', () => { [Error: Resolution of the following tree failed: - - fn:testFn - - fn:vec4f: Cannot convert value of type 'arrayOf(i32, 4)' to type 'f32'] + - fn:vec4f: Cannot convert value of type 'arrayOf(i32, 4)' to any of the target types: [f32]] `); }); diff --git a/packages/typegpu/tests/vector.test.ts b/packages/typegpu/tests/vector.test.ts index ae65b47ef2..34a4e7cfd2 100644 --- a/packages/typegpu/tests/vector.test.ts +++ b/packages/typegpu/tests/vector.test.ts @@ -36,7 +36,7 @@ describe('setters', () => { vec[0] = 1.1; vec[1] = -1.1; vec.z = 2.2; - expect(vec).toStrictEqual(d.vec3u(1, 4294967295, 2)); + expect(vec).toStrictEqual(d.vec3u(1, 0, 2)); }); }); From 67c05d53a2238a7c3ac8f67ab1d32bdc38e59147 Mon Sep 17 00:00:00 2001 From: Konrad Reczko Date: Tue, 3 Feb 2026 12:39:42 +0100 Subject: [PATCH 2/3] review fixes --- packages/typegpu/src/std/numeric.ts | 94 +++++++++++-------------- packages/typegpu/src/tgsl/conversion.ts | 2 +- 2 files changed, 42 insertions(+), 54 deletions(-) diff --git a/packages/typegpu/src/std/numeric.ts b/packages/typegpu/src/std/numeric.ts index 6c87344182..0bd5512882 100644 --- a/packages/typegpu/src/std/numeric.ts +++ b/packages/typegpu/src/std/numeric.ts @@ -67,18 +67,6 @@ const unaryIdentitySignature = (arg: BaseData) => { }; }; -const unaryIdentityRestrictedSignature = - (restrict: BaseData[]) => (arg: BaseData) => { - const argRestricted = unify([arg], restrict); - if (!argRestricted) { - throw new SignatureNotSupportedError([arg], restrict); - } - return { - argTypes: argRestricted, - returnType: argRestricted[0] as BaseData, - }; - }; - const variadicUnifySignature = (...args: BaseData[]) => { const uargs = unify(args) ?? args; return ({ @@ -87,7 +75,7 @@ const variadicUnifySignature = (...args: BaseData[]) => { }); }; -const variadicUnifyRestrictedSignature = (restrict: BaseData[]) => +const unifyRestrictedSignature = (restrict: BaseData[]) => ( ...args: BaseData[] ) => { @@ -170,7 +158,7 @@ function cpuAcos(value: T): T { export const acos = dualImpl({ name: 'acos', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAcos, codegenImpl: (_ctx, [value]) => stitch`acos(${value})`, }); @@ -186,7 +174,7 @@ function cpuAcosh(value: T): T { export const acosh = dualImpl({ name: 'acosh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAcosh, codegenImpl: (_ctx, [value]) => stitch`acosh(${value})`, }); @@ -202,7 +190,7 @@ function cpuAsin(value: T): T { export const asin = dualImpl({ name: 'asin', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAsin, codegenImpl: (_ctx, [value]) => stitch`asin(${value})`, }); @@ -218,7 +206,7 @@ function cpuAsinh(value: T): T { export const asinh = dualImpl({ name: 'asinh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAsinh, codegenImpl: (_ctx, [value]) => stitch`asinh(${value})`, }); @@ -234,7 +222,7 @@ function cpuAtan(value: T): T { export const atan = dualImpl({ name: 'atan', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAtan, codegenImpl: (_ctx, [value]) => stitch`atan(${value})`, }); @@ -250,7 +238,7 @@ function cpuAtanh(value: T): T { export const atanh = dualImpl({ name: 'atanh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAtanh, codegenImpl: (_ctx, [value]) => stitch`atanh(${value})`, }); @@ -269,7 +257,7 @@ function cpuAtan2(y: T, x: T): T { export const atan2 = dualImpl({ name: 'atan2', - signature: variadicUnifyRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuAtan2, codegenImpl: (_ctx, [y, x]) => stitch`atan2(${y}, ${x})`, }); @@ -285,7 +273,7 @@ function cpuCeil(value: T): T { export const ceil = dualImpl({ name: 'ceil', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuCeil, codegenImpl: (_ctx, [value]) => stitch`ceil(${value})`, }); @@ -322,7 +310,7 @@ function cpuCos(value: T): T { export const cos = dualImpl({ name: 'cos', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuCos, codegenImpl: (_ctx, [value]) => stitch`cos(${value})`, }); @@ -338,7 +326,7 @@ function cpuCosh(value: T): T { export const cosh = dualImpl({ name: 'cosh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuCosh, codegenImpl: (_ctx, [value]) => stitch`cosh(${value})`, }); @@ -353,7 +341,7 @@ function cpuCountLeadingZeros( export const countLeadingZeros = dualImpl({ name: 'countLeadingZeros', - signature: unaryIdentityRestrictedSignature(anyConcreteInteger), + signature: unifyRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countLeadingZeros not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countLeadingZeros(${value})`, @@ -369,7 +357,7 @@ function cpuCountOneBits( export const countOneBits = dualImpl({ name: 'countOneBits', - signature: unaryIdentityRestrictedSignature(anyConcreteInteger), + signature: unifyRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countOneBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countOneBits(${value})`, @@ -385,7 +373,7 @@ function cpuCountTrailingZeros( export const countTrailingZeros = dualImpl({ name: 'countTrailingZeros', - signature: unaryIdentityRestrictedSignature(anyConcreteInteger), + signature: unifyRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for countTrailingZeros not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`countTrailingZeros(${value})`, @@ -393,7 +381,7 @@ export const countTrailingZeros = dualImpl({ export const cross = dualImpl({ name: 'cross', - signature: variadicUnifyRestrictedSignature([vec3f, vec3h]), + signature: unifyRestrictedSignature([vec3f, vec3h]), normalImpl: (a: T, b: T): T => VectorOps.cross[a.kind](a, b), codegenImpl: (_ctx, [a, b]) => stitch`cross(${a}, ${b})`, @@ -412,7 +400,7 @@ function cpuDegrees(value: T): T { export const degrees = dualImpl({ name: 'degrees', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuDegrees, codegenImpl: (_ctx, [value]) => stitch`degrees(${value})`, }); @@ -501,7 +489,7 @@ function cpuExp(value: T): T { export const exp = dualImpl({ name: 'exp', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuExp, codegenImpl: (_ctx, [value]) => stitch`exp(${value})`, }); @@ -517,7 +505,7 @@ function cpuExp2(value: T): T { export const exp2 = dualImpl({ name: 'exp2', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuExp2, codegenImpl: (_ctx, [value]) => stitch`exp2(${value})`, }); @@ -558,7 +546,7 @@ export const faceForward = dualImpl< (e1: T, e2: T, e3: T) => T >({ name: 'faceForward', - signature: variadicUnifyRestrictedSignature(anyFloatVec), + signature: unifyRestrictedSignature(anyFloatVec), normalImpl: 'CPU implementation for faceForward not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [e1, e2, e3]) => stitch`faceForward(${e1}, ${e2}, ${e3})`, @@ -590,7 +578,7 @@ function cpuFirstTrailingBit( export const firstTrailingBit = dualImpl({ name: 'firstTrailingBit', - signature: unaryIdentityRestrictedSignature(anyConcreteInteger), + signature: unifyRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for firstTrailingBit not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`firstTrailingBit(${value})`, @@ -607,7 +595,7 @@ function cpuFloor(value: T): T { export const floor = dualImpl({ name: 'floor', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuFloor, codegenImpl: (_ctx, [arg]) => stitch`floor(${arg})`, }); @@ -629,7 +617,7 @@ function cpuFma( export const fma = dualImpl({ name: 'fma', - signature: variadicUnifyRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuFma, codegenImpl: (_ctx, [e1, e2, e3]) => stitch`fma(${e1}, ${e2}, ${e3})`, }); @@ -645,7 +633,7 @@ function cpuFract(value: T): T { export const fract = dualImpl({ name: 'fract', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuFract, codegenImpl: (_ctx, [a]) => stitch`fract(${a})`, }); @@ -737,7 +725,7 @@ function cpuInverseSqrt(value: T): T { export const inverseSqrt = dualImpl({ name: 'inverseSqrt', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuInverseSqrt, codegenImpl: (_ctx, [value]) => stitch`inverseSqrt(${value})`, }); @@ -818,7 +806,7 @@ function cpuLog(value: T): T { export const log = dualImpl({ name: 'log', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuLog, codegenImpl: (_ctx, [value]) => stitch`log(${value})`, }); @@ -834,7 +822,7 @@ function cpuLog2(value: T): T { export const log2 = dualImpl({ name: 'log2', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuLog2, codegenImpl: (_ctx, [value]) => stitch`log2(${value})`, }); @@ -902,7 +890,7 @@ function cpuMix( export const mix = dualImpl({ name: 'mix', - signature: (...[e1, e2, e3]) => { + signature: (e1, e2, e3) => { if (e1.type.startsWith('vec') && !e3.type.startsWith('vec')) { const uarg = unify([e3], [(e1 as unknown as Vec2f).primitive]); if (!uarg) { @@ -970,7 +958,7 @@ export const modf: ModfOverload = dualImpl({ export const normalize = dualImpl({ name: 'normalize', - signature: unaryIdentityRestrictedSignature(anyFloatVec), + signature: unifyRestrictedSignature(anyFloatVec), normalImpl: (v: T): T => VectorOps.normalize[v.kind](v), codegenImpl: (_ctx, [value]) => stitch`normalize(${value})`, @@ -996,7 +984,7 @@ function powCpu( export const pow = dualImpl({ name: 'pow', - signature: variadicUnifyRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: powCpu, codegenImpl: (_ctx, [lhs, rhs]) => stitch`pow(${lhs}, ${rhs})`, }); @@ -1036,7 +1024,7 @@ function cpuRadians(value: T): T { export const radians = dualImpl({ name: 'radians', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuRadians, codegenImpl: (_ctx, [value]) => stitch`radians(${value})`, }); @@ -1082,7 +1070,7 @@ function cpuReverseBits(value: T): T { export const reverseBits = dualImpl({ name: 'reverseBits', - signature: unaryIdentityRestrictedSignature(anyConcreteInteger), + signature: unifyRestrictedSignature(anyConcreteInteger), normalImpl: 'CPU implementation for reverseBits not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`reverseBits(${value})`, @@ -1101,7 +1089,7 @@ function cpuRound(value: T): T { export const round = dualImpl({ name: 'round', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuRound, codegenImpl: (_ctx, [value]) => stitch`round(${value})`, }); @@ -1119,7 +1107,7 @@ function cpuSaturate(value: T): T { export const saturate = dualImpl({ name: 'saturate', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuSaturate, codegenImpl: (_ctx, [value]) => stitch`saturate(${value})`, }); @@ -1158,7 +1146,7 @@ function cpuSin(value: T): T { export const sin = dualImpl({ name: 'sin', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuSin, codegenImpl: (_ctx, [value]) => stitch`sin(${value})`, }); @@ -1176,7 +1164,7 @@ function cpuSinh(value: T): T { export const sinh = dualImpl({ name: 'sinh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuSinh, codegenImpl: (_ctx, [value]) => stitch`sinh(${value})`, }); @@ -1208,7 +1196,7 @@ function cpuSmoothstep( export const smoothstep = dualImpl({ name: 'smoothstep', - signature: variadicUnifyRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuSmoothstep, codegenImpl: (_ctx, [edge0, edge1, x]) => stitch`smoothstep(${edge0}, ${edge1}, ${x})`, @@ -1225,7 +1213,7 @@ function cpuSqrt(value: T): T { export const sqrt = dualImpl({ name: 'sqrt', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuSqrt, codegenImpl: (_ctx, [value]) => stitch`sqrt(${value})`, }); @@ -1243,7 +1231,7 @@ function cpuStep(edge: T, x: T): T { export const step = dualImpl({ name: 'step', - signature: variadicUnifyRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuStep, codegenImpl: (_ctx, [edge, x]) => stitch`step(${edge}, ${x})`, }); @@ -1261,7 +1249,7 @@ function cpuTan(value: T): T { export const tan = dualImpl({ name: 'tan', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuTan, codegenImpl: (_ctx, [value]) => stitch`tan(${value})`, }); @@ -1277,7 +1265,7 @@ function cpuTanh(value: T): T { export const tanh = dualImpl({ name: 'tanh', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: cpuTanh, codegenImpl: (_ctx, [value]) => stitch`tanh(${value})`, }); @@ -1298,7 +1286,7 @@ function cpuTrunc(value: T): T { export const trunc = dualImpl({ name: 'trunc', - signature: unaryIdentityRestrictedSignature(anyFloat), + signature: unifyRestrictedSignature(anyFloat), normalImpl: 'CPU implementation for trunc not implemented yet. Please submit an issue at https://github.com/software-mansion/TypeGPU/issues', codegenImpl: (_ctx, [value]) => stitch`trunc(${value})`, diff --git a/packages/typegpu/src/tgsl/conversion.ts b/packages/typegpu/src/tgsl/conversion.ts index 0b5e2f9664..f2480a5eb6 100644 --- a/packages/typegpu/src/tgsl/conversion.ts +++ b/packages/typegpu/src/tgsl/conversion.ts @@ -353,7 +353,7 @@ export function tryConvertSnippet( return snip(value, target, origin); } - if (typeof dataType === 'symbol') { + if (dataType === UnknownData) { // Commit unknown to the expected type. return snip(stitch`${snip(value, target, origin)}`, target, origin); } From e305f4fefa6c327d663be2096517acd5a1080b17 Mon Sep 17 00:00:00 2001 From: Konrad Reczko Date: Mon, 9 Feb 2026 13:44:50 +0100 Subject: [PATCH 3/3] review fixes --- packages/typegpu/src/core/function/createCallableSchema.ts | 2 +- packages/typegpu/src/data/numeric.ts | 2 +- packages/typegpu/src/std/array.ts | 2 +- packages/typegpu/src/std/numeric.ts | 7 +++---- packages/typegpu/src/std/operators.ts | 2 +- packages/typegpu/tests/array.test.ts | 2 +- .../tests/examples/individual/tgsl-parsing-test.test.ts | 4 ++-- 7 files changed, 10 insertions(+), 11 deletions(-) diff --git a/packages/typegpu/src/core/function/createCallableSchema.ts b/packages/typegpu/src/core/function/createCallableSchema.ts index 7266c7ac6e..3b3c438f38 100644 --- a/packages/typegpu/src/core/function/createCallableSchema.ts +++ b/packages/typegpu/src/core/function/createCallableSchema.ts @@ -9,9 +9,9 @@ import { NormalState, type ResolutionCtx, } from '../../types.ts'; +import type { AnyFn } from './fnTypes.ts'; type MapValueToDataType = { [K in keyof T]: BaseData }; -type AnyFn = (...args: never[]) => unknown; interface CallableSchemaOptions { readonly name: string; diff --git a/packages/typegpu/src/data/numeric.ts b/packages/typegpu/src/data/numeric.ts index dd74d292c6..bab2e69d01 100644 --- a/packages/typegpu/src/data/numeric.ts +++ b/packages/typegpu/src/data/numeric.ts @@ -162,7 +162,7 @@ export const i32: I32 = Object.assign(i32Cast, { const f32Cast = callableSchema({ name: 'f32', - signature: (arg) => ({ argTypes: [arg ? [arg] : []], returnType: f32 }), + signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f32 }), normalImpl(v?: number | boolean) { if (v === undefined) { return 0; diff --git a/packages/typegpu/src/std/array.ts b/packages/typegpu/src/std/array.ts index a257912582..c853bb7601 100644 --- a/packages/typegpu/src/std/array.ts +++ b/packages/typegpu/src/std/array.ts @@ -23,6 +23,6 @@ export const arrayLength = dualImpl({ isRef(a) ? a.$.length : a.length, codegenImpl(_ctx, [a]) { const length = sizeOfPointedToArray(a.dataType); - return length > 0 ? String(length) : stitch`arrayLength(${a})`; + return length > 0 ? `${length}u` : stitch`arrayLength(${a})`; }, }); diff --git a/packages/typegpu/src/std/numeric.ts b/packages/typegpu/src/std/numeric.ts index 0bd5512882..18d007380f 100644 --- a/packages/typegpu/src/std/numeric.ts +++ b/packages/typegpu/src/std/numeric.ts @@ -1,6 +1,5 @@ import { dualImpl, MissingCpuImplError } from '../core/function/dualImpl.ts'; import { stitch } from '../core/resolve/stitch.ts'; -import type { AnyData } from '../data/dataTypes.ts'; import { mat2x2f, mat3x3f, mat4x4f } from '../data/matrix.ts'; import { smoothstepScalar } from '../data/numberOps.ts'; import { @@ -703,7 +702,7 @@ export const insertBits = dualImpl({ } return { argTypes: [...uargs, u32, u32], - returnType: uargs[0] as AnyData, + returnType: uargs[0] as BaseData, }; }, normalImpl: @@ -898,13 +897,13 @@ export const mix = dualImpl({ (e1 as unknown as Vec2f).primitive, ]); } - return { argTypes: [e1, e2, uarg[0] as AnyData], returnType: e1 }; + return { argTypes: [e1, e2, uarg[0] as BaseData], returnType: e1 }; } const uargs = unify([e1, e2, e3], anyFloat); if (!uargs) { throw new SignatureNotSupportedError([e1, e2, e3], anyFloat); } - return { argTypes: uargs, returnType: uargs[0] as AnyData }; + return { argTypes: uargs, returnType: uargs[0] as BaseData }; }, normalImpl: cpuMix, codegenImpl: (_ctx, [e1, e2, e3]) => stitch`mix(${e1}, ${e2}, ${e3})`, diff --git a/packages/typegpu/src/std/operators.ts b/packages/typegpu/src/std/operators.ts index 50418023fc..7bc8c5a057 100644 --- a/packages/typegpu/src/std/operators.ts +++ b/packages/typegpu/src/std/operators.ts @@ -236,7 +236,7 @@ type ModOverload = { */ export const mod: ModOverload = dualImpl({ name: 'mod', - signature: binaryArithmeticSignature, + signature: binaryDivSignature, normalImpl(a: T, b: T): T { if (typeof a === 'number' && typeof b === 'number') { return (a % b) as T; // scalar % scalar diff --git a/packages/typegpu/tests/array.test.ts b/packages/typegpu/tests/array.test.ts index 7878524451..0fc4c7839f 100644 --- a/packages/typegpu/tests/array.test.ts +++ b/packages/typegpu/tests/array.test.ts @@ -446,7 +446,7 @@ describe('array.length', () => { expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` "fn testFn() -> i32 { - return 5; + return 5u; }" `); }); diff --git a/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts b/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts index 40a3b2632d..9d783d65a0 100644 --- a/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts +++ b/packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts @@ -124,11 +124,11 @@ describe('tgsl parsing test example', () => { fn arrayAndStructConstructorsTest() -> bool { var s = true; var defaultComplexStruct = ComplexStruct(); - s = (s && (2 == 2i)); + s = (s && (2u == 2i)); s = (s && (defaultComplexStruct.arr[0i] == 0i)); s = (s && (defaultComplexStruct.arr[1i] == 0i)); var defaultComplexArray = array(); - s = (s && (3 == 3i)); + s = (s && (3u == 3i)); s = (s && all(defaultComplexArray[0i].vec == vec2f())); s = (s && all(defaultComplexArray[1i].vec == vec2f())); s = (s && all(defaultComplexArray[2i].vec == vec2f()));