Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions packages/typegpu/src/core/function/createCallableSchema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
import { type BaseData, isPtr } from '../../data/wgslTypes.ts';
import { setName } from '../../shared/meta.ts';
import { $gpuCallable } from '../../shared/symbols.ts';
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
import {
type DualFn,
isKnownAtComptime,
NormalState,
type ResolutionCtx,
} from '../../types.ts';
import type { AnyFn } from './fnTypes.ts';

type MapValueToDataType<T> = { [K in keyof T]: BaseData };

interface CallableSchemaOptions<T extends AnyFn> {
readonly name: string;
readonly normalImpl: T;
readonly codegenImpl: (
ctx: ResolutionCtx,
args: MapValueToSnippet<Parameters<T>>,
) => string;
readonly signature: (
...inArgTypes: MapValueToDataType<Parameters<T>>
) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData };
}

export function callableSchema<T extends AnyFn>(
options: CallableSchemaOptions<T>,
): DualFn<T> {
const impl = ((...args: Parameters<T>) => {
return options.normalImpl(...args);
}) as DualFn<T>;

setName(impl, options.name);
impl.toString = () => options.name;
impl[$gpuCallable] = {
get strictSignature() {
return undefined;
},
call(ctx, args) {
const { argTypes, returnType } = options.signature(
...args.map((s) => {
// Dereference implicit pointers
if (isPtr(s.dataType) && s.dataType.implicit) {
return s.dataType.inner;
}
return s.dataType;
}) as MapValueToDataType<Parameters<T>>,
);

const converted = args.map((s, idx) => {
const argType = argTypes[idx];
if (!argType) {
throw new Error('Function called with invalid arguments');
}
return tryConvertSnippet(
ctx,
s,
argType,
false,
);
}) as MapValueToSnippet<Parameters<T>>;

if (converted.every((s) => isKnownAtComptime(s))) {
ctx.pushMode(new NormalState());
try {
return snip(
options.normalImpl(...converted.map((s) => s.value) as never[]),
returnType,
// Functions give up ownership of their return value
/* origin */ 'constant',
);
} finally {
ctx.popMode('normal');
}
}

return snip(
options.codegenImpl(ctx, converted),
returnType,
// Functions give up ownership of their return value
/* origin */ 'runtime',
);
},
};

return impl;
}
10 changes: 7 additions & 3 deletions packages/typegpu/src/core/function/dualImpl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
import { setName } from '../../shared/meta.ts';
import { $gpuCallable } from '../../shared/symbols.ts';
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
import { concretize } from '../../tgsl/generationHelpers.ts';
import {
type DualFn,
isKnownAtComptime,
Expand All @@ -21,10 +22,13 @@ interface DualImplOptions<T extends AnyFn> {
args: MapValueToSnippet<Parameters<T>>,
) => string;
readonly signature:
| { argTypes: BaseData[]; returnType: BaseData }
| {
argTypes: (BaseData | BaseData[])[];
returnType: BaseData;
}
| ((
...inArgTypes: MapValueToDataType<Parameters<T>>
) => { argTypes: BaseData[]; returnType: BaseData });
) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData });
/**
* Whether the function should skip trying to execute the "normal" implementation if
* all arguments are known at compile time.
Expand Down Expand Up @@ -112,7 +116,7 @@ export function dualImpl<T extends AnyFn>(

return snip(
options.codegenImpl(ctx, converted),
returnType,
concretize(returnType),
// Functions give up ownership of their return value
/* origin */ 'runtime',
);
Expand Down
4 changes: 2 additions & 2 deletions packages/typegpu/src/data/matrix.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { comptime } from '../core/function/comptime.ts';
import { callableSchema } from '../core/function/createCallableSchema.ts';
import { dualImpl } from '../core/function/dualImpl.ts';
import { stitch } from '../core/resolve/stitch.ts';
import { $repr } from '../shared/symbols.ts';
Expand Down Expand Up @@ -64,7 +65,7 @@ function createMatSchema<
>(
options: MatSchemaOptions<TType, ColumnType>,
): { type: TType; [$repr]: ValueType } & MatConstructor<ValueType, ColumnType> {
const construct = dualImpl({
const construct = callableSchema({
name: options.type,
normalImpl: (...args: (number | ColumnType)[]): ValueType => {
const elements: number[] = [];
Expand Down Expand Up @@ -94,7 +95,6 @@ function createMatSchema<

return new options.MatImpl(...elements) as ValueType;
},
ignoreImplicitCastWarning: true,
signature: (...args) => ({
argTypes: args.map((arg) => (isVec(arg) ? arg : f32)),
returnType: schema as unknown as BaseData,
Expand Down
23 changes: 17 additions & 6 deletions packages/typegpu/src/data/numeric.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { stitch } from '../core/resolve/stitch.ts';
import { dualImpl } from '../core/function/dualImpl.ts';
import { $internal } from '../shared/symbols.ts';
import type {
AbstractFloat,
Expand All @@ -11,6 +10,7 @@ import type {
U16,
U32,
} from './wgslTypes.ts';
import { callableSchema } from '../core/function/createCallableSchema.ts';

export const abstractInt = {
[$internal]: {},
Expand All @@ -28,7 +28,7 @@ export const abstractFloat = {
},
} as AbstractFloat;

const boolCast = dualImpl({
const boolCast = callableSchema({
name: 'bool',
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: bool }),
normalImpl(v?: number | boolean) {
Expand Down Expand Up @@ -66,7 +66,7 @@ export const bool: Bool = Object.assign(boolCast, {
type: 'bool',
}) as unknown as Bool;

const u32Cast = dualImpl({
const u32Cast = callableSchema({
name: 'u32',
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: u32 }),
normalImpl(v?: number | boolean) {
Expand All @@ -76,6 +76,17 @@ const u32Cast = dualImpl({
if (typeof v === 'boolean') {
return v ? 1 : 0;
}
if (!Number.isInteger(v)) {
const truncated = Math.trunc(v);
if (truncated < 0) {
return 0;
}
if (truncated > 0xffffffff) {
return 0xffffffff;
}
return truncated;
}
// Integer input: treat as bit reinterpretation (i32 -> u32)
return (v & 0xffffffff) >>> 0;
},
codegenImpl: (_ctx, [arg]) =>
Expand Down Expand Up @@ -106,7 +117,7 @@ export const u32: U32 = Object.assign(u32Cast, {
type: 'u32',
}) as unknown as U32;

const i32Cast = dualImpl({
const i32Cast = callableSchema({
name: 'i32',
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: i32 }),
normalImpl(v?: number | boolean) {
Expand Down Expand Up @@ -149,7 +160,7 @@ export const i32: I32 = Object.assign(i32Cast, {
type: 'i32',
}) as unknown as I32;

const f32Cast = dualImpl({
const f32Cast = callableSchema({
name: 'f32',
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f32 }),
normalImpl(v?: number | boolean) {
Expand Down Expand Up @@ -275,7 +286,7 @@ function roundToF16(x: number): number {
return fromHalfBits(toHalfBits(x));
}

const f16Cast = dualImpl({
const f16Cast = callableSchema({
name: 'f16',
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f16 }),
normalImpl(v?: number | boolean) {
Expand Down
5 changes: 2 additions & 3 deletions packages/typegpu/src/data/vector.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { dualImpl } from '../core/function/dualImpl.ts';
import { callableSchema } from '../core/function/createCallableSchema.ts';
import { stitch } from '../core/resolve/stitch.ts';
import { $internal, $repr } from '../shared/symbols.ts';
import { bool, f16, f32, i32, u32 } from './numeric.ts';
Expand Down Expand Up @@ -307,14 +307,13 @@ function makeVecSchema<TValue, S extends number | boolean>(
);
};

const construct = dualImpl({
const construct = callableSchema({
name: type,
signature: (...args) => ({
argTypes: args.map((arg) => isVec(arg) ? arg : primitive),
returnType: schema as unknown as BaseData,
}),
normalImpl: cpuConstruct,
ignoreImplicitCastWarning: true,
codegenImpl: (_ctx, args) => {
if (
args.length === 1 && args[0]?.dataType === schema as unknown as BaseData
Expand Down
17 changes: 17 additions & 0 deletions packages/typegpu/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,20 @@ export class WgslTypeError extends Error {
Object.setPrototypeOf(this, WgslTypeError.prototype);
}
}

export class SignatureNotSupportedError extends Error {
constructor(actual: BaseData[], candidates: BaseData[]) {
super(
`Unsupported data types: ${
actual.map((a) => a.type).join(', ')
}. Supported types are: ${
candidates
.map((r) => r.type)
.join(', ')
}.`,
);

// Set the prototype explicitly.
Object.setPrototypeOf(this, SignatureNotSupportedError.prototype);
}
}
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ export const arrayLength = dualImpl({
isRef(a) ? a.$.length : a.length,
codegenImpl(_ctx, [a]) {
const length = sizeOfPointedToArray(a.dataType);
return length > 0 ? String(length) : stitch`arrayLength(${a})`;
return length > 0 ? `${length}u` : stitch`arrayLength(${a})`;
},
});
1 change: 1 addition & 0 deletions packages/typegpu/src/std/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ export {

export {
textureDimensions,
textureGather,
textureLoad,
textureSample,
textureSampleBaseClampToEdge,
Expand Down
Loading
Loading