Skip to content

Commit

Permalink
Change vertexFn parameters to accept a struct each. Handle TGSL ver…
Browse files Browse the repository at this point in the history
…tex shader resolution (#802)
  • Loading branch information
reczkok authored Feb 5, 2025
1 parent 9d5f6cf commit 4082c22
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ const vertexFunction = tgpu['~unstable']
{ vertexIndex: d.builtin.vertexIndex },
{ outPos: d.builtin.position },
)
.does(/* wgsl */ `(@builtin(vertex_index) vertexIndex: u32) -> VertexOutput {
.does(/* wgsl */ `(input: VertexInput) -> VertexOutput {
var pos = array<vec2f, 6>(
vec2<f32>( 1, 1),
vec2<f32>( 1, -1),
Expand All @@ -196,7 +196,7 @@ const vertexFunction = tgpu['~unstable']
);
var output: VertexOutput;
output.outPos = vec4f(pos[vertexIndex], 0, 1);
output.outPos = vec4f(pos[input.vertexIndex], 0, 1);
return output;
}`)
.$name('vertex_main');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ const VertexOutput = {

const mainVert = tgpu['~unstable']
.vertexFn({ v: d.vec2f, center: d.vec2f, velocity: d.vec2f }, VertexOutput)
.does(/* wgsl */ `(@location(0) v: vec2f, @location(1) center: vec2f, @location(2) velocity: vec2f) -> VertexOutput {
let angle = getRotationFromVelocity(velocity);
let rotated = rotate(v, angle);
.does(/* wgsl */ `(input: VertexInput) -> VertexOutput {
let angle = getRotationFromVelocity(input.velocity);
let rotated = rotate(input.v, angle);
let pos = vec4(rotated + center, 0.0, 1.0);
let pos = vec4(rotated + input.center, 0.0, 1.0);
let color = vec4(
sin(angle + colorPalette.r) * 0.45 + 0.45,
Expand Down Expand Up @@ -223,7 +223,7 @@ const mainCompute = tgpu['~unstable']
var alignmentCount = 0u;
var cohesion = vec2(0.0, 0.0);
var cohesionCount = 0u;
for (var i = 0u; i < arrayLength(&currentTrianglePos); i = i + 1) {
if (i == index) {
continue;
Expand Down Expand Up @@ -253,7 +253,7 @@ const mainCompute = tgpu['~unstable']
+ (alignment * params.alignmentStrength)
+ (cohesion * params.cohesionStrength);
instanceInfo.velocity = normalize(instanceInfo.velocity) * clamp(length(instanceInfo.velocity), 0.0, 0.01);
if (instanceInfo.position[0] > 1.0 + triangleSize) {
instanceInfo.position[0] = -1.0 - triangleSize;
}
Expand Down Expand Up @@ -340,7 +340,7 @@ export const controls = {
onButtonClick: () => paramsBuffer.write(presets.blobs),
},

'⚛ Particles': {
'⚛ Particles': {
onButtonClick: () => paramsBuffer.write(presets.particles),
},

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,24 @@ const mainVert = tgpu['~unstable']
VertexOutput,
)
.does(
/* wgsl */ `(
@location(0) tilt: f32,
@location(1) angle: f32,
@location(2) color: vec4f,
@location(3) center: vec2f,
@builtin(vertex_index) index: u32,
) -> VertexOutput {
let width = tilt;
let height = tilt / 2;
/* wgsl */ `(input: VertexInput) -> VertexOutput {
let width = input.tilt;
let height = input.tilt / 2;
var pos = rotate(array<vec2f, 4>(
vec2f(0, 0),
vec2f(width, 0),
vec2f(0, height),
vec2f(width, height),
)[index] / 350, angle) + center;
)[input.index] / 350, input.angle) + input.center;
if (canvasAspectRatio < 1) {
pos.x /= canvasAspectRatio;
} else {
pos.y *= canvasAspectRatio;
}
return VertexOutput(vec4f(pos, 0.0, 1.0), color);
return VertexOutput(vec4f(pos, 0.0, 1.0), input.color);
}`,
)
.$uses({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ const vertexMain = tgpu['~unstable']
{ idx: d.builtin.vertexIndex },
{ pos: d.builtin.position, uv: d.vec2f },
)
.does(/* wgsl */ `(@builtin(vertex_index) idx: u32) -> VertexOut {
.does(/* wgsl */ `(input: VertexInput) -> VertexOut {
var pos = array<vec2f, 4>(
vec2(1, 1), // top-right
vec2(-1, 1), // top-left
Expand All @@ -506,8 +506,8 @@ const vertexMain = tgpu['~unstable']
);
var output: VertexOut;
output.pos = vec4f(pos[idx].x, pos[idx].y, 0.0, 1.0);
output.uv = uv[idx];
output.pos = vec4f(pos[input.idx].x, pos[input.idx].y, 0.0, 1.0);
output.uv = uv[input.idx];
return output;
}`);

Expand Down
8 changes: 6 additions & 2 deletions packages/typegpu/src/core/function/ioOutputType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export type WithLocations<T extends IORecord> = {
: Decorate<T[Key], Location<number>>;
};

export type IOLayoutToOutputSchema<T extends IOLayout> = T extends BaseWgslData
export type IOLayoutToSchema<T extends IOLayout> = T extends BaseWgslData
? Decorate<T, Location<0>>
: T extends IORecord
? TgpuStruct<WithLocations<T>>
Expand Down Expand Up @@ -56,5 +56,9 @@ export function createOutputType<T extends IOData>(returnType: IOLayout<T>) {
isData(returnType)
? location(0, returnType)
: struct(withLocations(returnType) as Record<string, T>)
) as IOLayoutToOutputSchema<IOLayout<T>>;
) as IOLayoutToSchema<IOLayout<T>>;
}

export function createStructFromIO<T extends IOData>(members: IORecord<T>) {
return struct(withLocations(members) as Record<string, T>);
}
10 changes: 5 additions & 5 deletions packages/typegpu/src/core/function/tgpuFragmentFn.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { OmitBuiltins } from '../../builtin';
import { type Vec4f, isWgslStruct } from '../../data/wgslTypes';
import type { TgpuNamable } from '../../namable';
import type { Vec4f } from '../../data/wgslTypes';
import { type TgpuNamable, isNamable } from '../../namable';
import type { Labelled, ResolutionCtx, SelfResolvable } from '../../types';
import { addReturnTypeToExternals } from '../resolve/externals';
import { createFnCore } from './fnCore';
Expand All @@ -11,7 +11,7 @@ import type {
Implementation,
InferIO,
} from './fnTypes';
import { type IOLayoutToOutputSchema, createOutputType } from './ioOutputType';
import { type IOLayoutToSchema, createOutputType } from './ioOutputType';

// ----------
// Public API
Expand Down Expand Up @@ -50,7 +50,7 @@ export interface TgpuFragmentFn<
Output extends IOLayout<Vec4f> = IOLayout<Vec4f>,
> extends TgpuNamable {
readonly shell: TgpuFragmentFnShell<Varying, Output>;
readonly outputType: IOLayoutToOutputSchema<Output>;
readonly outputType: IOLayoutToSchema<Output>;

$uses(dependencyMap: Record<string, unknown>): this;
}
Expand Down Expand Up @@ -120,7 +120,7 @@ function createFragmentFn(

$name(newLabel: string): This {
core.label = newLabel;
if (isWgslStruct(outputType)) {
if (isNamable(outputType)) {
outputType.$name(`${newLabel}_Output`);
}
return this;
Expand Down
51 changes: 40 additions & 11 deletions packages/typegpu/src/core/function/tgpuVertexFn.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { OmitBuiltins } from '../../builtin';
import { isWgslStruct } from '../../data/wgslTypes';
import type { TgpuNamable } from '../../namable';
import type { AnyWgslStruct } from '../../data/wgslTypes';
import { type TgpuNamable, isNamable } from '../../namable';
import type { GenerationCtx } from '../../smol/wgslGenerator';
import type { Labelled, ResolutionCtx, SelfResolvable } from '../../types';
import { addReturnTypeToExternals } from '../resolve/externals';
import { createFnCore } from './fnCore';
Expand All @@ -11,7 +12,11 @@ import type {
Implementation,
InferIO,
} from './fnTypes';
import { type IOLayoutToOutputSchema, createOutputType } from './ioOutputType';
import {
type IOLayoutToSchema,
createOutputType,
createStructFromIO,
} from './ioOutputType';

// ----------
// Public API
Expand All @@ -24,8 +29,9 @@ export interface TgpuVertexFnShell<
VertexIn extends IOLayout,
VertexOut extends IOLayout,
> {
readonly argTypes: [VertexIn];
readonly argTypes: [AnyWgslStruct];
readonly returnType: VertexOut;
readonly attributes: [VertexIn];

/**
* Creates a type-safe implementation of this signature
Expand All @@ -50,7 +56,8 @@ export interface TgpuVertexFn<
VertexOut extends IOLayout = IOLayout,
> extends TgpuNamable {
readonly shell: TgpuVertexFnShell<VertexIn, VertexOut>;
readonly outputType: IOLayoutToOutputSchema<VertexOut>;
readonly outputType: IOLayoutToSchema<VertexOut>;
readonly inputType: IOLayoutToSchema<VertexIn>;

$uses(dependencyMap: Record<string, unknown>): this;
}
Expand All @@ -67,7 +74,7 @@ export interface TgpuVertexFn<
* passed onto the fragment shader stage.
*/
export function vertexFn<
VertexIn extends IOLayout,
VertexIn extends IORecord,
// Not allowing single-value output, as it is better practice
// to properly label what the vertex shader is outputting.
VertexOut extends IORecord,
Expand All @@ -76,8 +83,9 @@ export function vertexFn<
outputType: VertexOut,
): TgpuVertexFnShell<ExoticIO<VertexIn>, ExoticIO<VertexOut>> {
return {
argTypes: [inputType as ExoticIO<VertexIn>],
returnType: outputType as ExoticIO<VertexOut>,
attributes: [inputType as ExoticIO<VertexIn>],
returnType: createOutputType(outputType) as ExoticIO<VertexOut>,
argTypes: [createStructFromIO(inputType)],

does(implementation) {
// biome-ignore lint/suspicious/noExplicitAny: <no thanks>
Expand All @@ -97,7 +105,8 @@ function createVertexFn(
type This = TgpuVertexFn<IOLayout, IOLayout> & Labelled & SelfResolvable;

const core = createFnCore(shell, implementation);
const outputType = createOutputType(shell.returnType);
const outputType = shell.returnType;
const inputType = shell.argTypes[0];
if (typeof implementation === 'string') {
addReturnTypeToExternals(implementation, outputType, (externals) =>
core.applyExternals(externals),
Expand All @@ -107,6 +116,7 @@ function createVertexFn(
return {
shell,
outputType,
inputType,

get label() {
return core.label;
Expand All @@ -119,14 +129,33 @@ function createVertexFn(

$name(newLabel: string): This {
core.label = newLabel;
if (isWgslStruct(outputType)) {
if (isNamable(outputType)) {
outputType.$name(`${newLabel}_Output`);
}
if (isNamable(inputType)) {
inputType.$name(`${newLabel}_Input`);
}
return this;
},

'~resolve'(ctx: ResolutionCtx): string {
return core.resolve(ctx, '@vertex ');
if (typeof implementation === 'string') {
return core.resolve(ctx, '@vertex ');
}

const generationCtx = ctx as GenerationCtx;
if (generationCtx.callStack === undefined) {
throw new Error(
'Cannot resolve a TGSL function outside of a generation context',
);
}

try {
generationCtx.callStack.push(outputType);
return core.resolve(ctx, '@vertex ');
} finally {
generationCtx.callStack.pop();
}
},

toString() {
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/core/pipeline/renderPipeline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ class RenderPipelineCore {

constructor(public readonly options: RenderPipelineCoreOptions) {
const connectedAttribs = connectAttributesToShader(
options.vertexFn.shell.argTypes[0],
options.vertexFn.shell.attributes[0],
options.vertexAttribs,
);

Expand Down
13 changes: 13 additions & 0 deletions packages/typegpu/src/smol/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,19 @@ function generateStatement(
}

if ('r' in statement) {
// check if the thing at the top of the call stack is a struct
// if so wrap the value returned in a constructor of the struct (its resolved name)
if (
isWgslStruct(ctx.callStack[ctx.callStack.length - 1]) &&
statement.r !== null
) {
const resource = resolveRes(ctx, generateExpression(ctx, statement.r));
const resolvedStruct = ctx.resolve(
ctx.callStack[ctx.callStack.length - 1],
);
return `${ctx.pre}return ${resolvedStruct}(${resource});`;
}

return statement.r === null
? `${ctx.pre}return;`
: `${ctx.pre}return ${resolveRes(ctx, generateExpression(ctx, statement.r))};`;
Expand Down
6 changes: 3 additions & 3 deletions packages/typegpu/tests/ioOutputType.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, expect, expectTypeOf, it } from 'vitest';
import {
type IOLayoutToOutputSchema,
type IOLayoutToSchema,
withLocations,
} from '../src/core/function/ioOutputType';
import * as d from '../src/data';
Expand Down Expand Up @@ -36,7 +36,7 @@ describe('withLocations', () => {
describe('IOLayoutToOutputSchema', () => {
it('decorates types in a struct with location attribute for non-builtins and no custom locations', () => {
expectTypeOf<
IOLayoutToOutputSchema<{
IOLayoutToSchema<{
a: d.Decorated<d.Vec4f, [d.Location<5>]>;
b: d.Vec4f;
pos: d.BuiltinPosition;
Expand All @@ -51,7 +51,7 @@ describe('IOLayoutToOutputSchema', () => {
});

it('decorates non-struct types', () => {
expectTypeOf<IOLayoutToOutputSchema<d.Vec4f>>().toEqualTypeOf<
expectTypeOf<IOLayoutToSchema<d.Vec4f>>().toEqualTypeOf<
d.Decorated<d.Vec4f, [d.Location<0>]>
>();
});
Expand Down
4 changes: 2 additions & 2 deletions packages/typegpu/tests/rawFn.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ describe('tgpu.fn with raw string WGSL implementation', () => {
template: `
fn vs() {
out.highlighted = highlighted.index;
let h = highlighted;
let x = a.b.c.highlighted.d;
}
Expand Down Expand Up @@ -285,7 +285,7 @@ struct fragment_Output {
const func = tgpu['~unstable']
.fn([d.vec4f, Point], d.vec2f)
.does(/* wgsl */ `(
a: vec4f,
a: vec4f,
b : PointStruct ,
) -> vec2f {
var newPoint: PointStruct;
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/tests/renderPipeline.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ describe('Inter-Stage Variables', () => {
.does('() { layout.bound.alpha; }')
.$uses({ layout });

const fragmentFn = utgpu.vertexFn({}, { out: d.vec4f }).does('() {}');
const fragmentFn = utgpu.fragmentFn({}, { out: d.vec4f }).does('() {}');

const pipeline = root
.withVertex(vertexFn, {})
Expand Down
Loading

0 comments on commit 4082c22

Please sign in to comment.