Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change vertexFn parameters to accept a struct each. Handle TGSL vertex shader resolution #802

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions packages/typegpu/src/core/function/ioOutputType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ export function createOutputType<T extends IOData>(returnType: IOLayout<T>) {
: struct(withLocations(returnType) as Record<string, T>)
) as IOLayoutToOutputSchema<IOLayout<T>>;
}

export function createStructFromIO<T extends IOData>(members: IORecord<T>) {
return struct(withLocations(members) as Record<string, T>);
}
6 changes: 3 additions & 3 deletions packages/typegpu/src/core/function/tgpuFragmentFn.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { OmitBuiltins } from '../../builtin';
import { type Vec4f, isWgslStruct } from '../../data/wgslTypes';
import type { TgpuNamable } from '../../namable';
import type { Vec4f } from '../../data/wgslTypes';
import { type TgpuNamable, isNamable } from '../../namable';
import type { Labelled, ResolutionCtx, SelfResolvable } from '../../types';
import { addReturnTypeToExternals } from '../resolve/externals';
import { createFnCore } from './fnCore';
Expand Down Expand Up @@ -120,7 +120,7 @@ function createFragmentFn(

$name(newLabel: string): This {
core.label = newLabel;
if (isWgslStruct(outputType)) {
if (isNamable(outputType)) {
outputType.$name(`${newLabel}_Output`);
}
return this;
Expand Down
44 changes: 33 additions & 11 deletions packages/typegpu/src/core/function/tgpuVertexFn.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { OmitBuiltins } from '../../builtin';
import { isWgslStruct } from '../../data/wgslTypes';
import type { TgpuNamable } from '../../namable';
import type { AnyWgslStruct } from '../../data/wgslTypes';
import { type TgpuNamable, isNamable } from '../../namable';
import type { GenerationCtx } from '../../smol/wgslGenerator';
import type { Labelled, ResolutionCtx, SelfResolvable } from '../../types';
import { addReturnTypeToExternals } from '../resolve/externals';
import { createFnCore } from './fnCore';
Expand All @@ -11,7 +12,11 @@ import type {
Implementation,
InferIO,
} from './fnTypes';
import { type IOLayoutToOutputSchema, createOutputType } from './ioOutputType';
import {
type IOLayoutToOutputSchema,
createOutputType,
createStructFromIO,
} from './ioOutputType';

// ----------
// Public API
Expand All @@ -24,8 +29,9 @@ export interface TgpuVertexFnShell<
VertexIn extends IOLayout,
VertexOut extends IOLayout,
> {
readonly argTypes: [VertexIn];
readonly argTypes: [AnyWgslStruct];
readonly returnType: VertexOut;
readonly attributes: [VertexIn];

/**
* Creates a type-safe implementation of this signature
Expand Down Expand Up @@ -67,7 +73,7 @@ export interface TgpuVertexFn<
* passed onto the fragment shader stage.
*/
export function vertexFn<
VertexIn extends IOLayout,
VertexIn extends IORecord,
// Not allowing single-value output, as it is better practice
// to properly label what the vertex shader is outputting.
VertexOut extends IORecord,
Expand All @@ -76,8 +82,9 @@ export function vertexFn<
outputType: VertexOut,
): TgpuVertexFnShell<ExoticIO<VertexIn>, ExoticIO<VertexOut>> {
return {
argTypes: [inputType as ExoticIO<VertexIn>],
returnType: outputType as ExoticIO<VertexOut>,
attributes: [inputType as ExoticIO<VertexIn>],
returnType: createOutputType(outputType) as ExoticIO<VertexOut>,
argTypes: [createStructFromIO(inputType).$name('VertexInput')],

does(implementation) {
// biome-ignore lint/suspicious/noExplicitAny: <no thanks>
Expand All @@ -97,9 +104,9 @@ function createVertexFn(
type This = TgpuVertexFn<IOLayout, IOLayout> & Labelled & SelfResolvable;

const core = createFnCore(shell, implementation);
const outputType = createOutputType(shell.returnType);
const outputType = shell.returnType;
if (typeof implementation === 'string') {
addReturnTypeToExternals(implementation, outputType, (externals) =>
addReturnTypeToExternals(implementation, shell.returnType, (externals) =>
core.applyExternals(externals),
);
}
Expand All @@ -119,14 +126,29 @@ function createVertexFn(

$name(newLabel: string): This {
core.label = newLabel;
if (isWgslStruct(outputType)) {
if (isNamable(outputType)) {
outputType.$name(`${newLabel}_Output`);
}
return this;
},

'~resolve'(ctx: ResolutionCtx): string {
return core.resolve(ctx, '@vertex ');
if (typeof implementation === 'string') {
return core.resolve(ctx, '@vertex ');
}

const generationCtx = ctx as GenerationCtx;
if (generationCtx.callStack === undefined) {
throw new Error(
'Cannot resolve a TGSL function outside of a generation context',
);
}

generationCtx.callStack.push(outputType);
const resolved = core.resolve(ctx, '@vertex ');
generationCtx.callStack.pop();
reczkok marked this conversation as resolved.
Show resolved Hide resolved

return resolved;
Comment on lines +147 to +151
Copy link
Collaborator Author

@reczkok reczkok Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know which one I prefer 🤔 @mhawryluk @iwoplaza?

Suggested change
generationCtx.callStack.push(outputType);
const resolved = core.resolve(ctx, '@vertex ');
generationCtx.callStack.pop();
return resolved;
try {
generationCtx.callStack.push(outputType);
return core.resolve(ctx, '@vertex ');
} finally {
generationCtx.callStack.pop();
}

},

toString() {
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/core/pipeline/renderPipeline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ class RenderPipelineCore {

constructor(public readonly options: RenderPipelineCoreOptions) {
const connectedAttribs = connectAttributesToShader(
options.vertexFn.shell.argTypes[0],
options.vertexFn.shell.attributes[0],
options.vertexAttribs,
);

Expand Down
9 changes: 9 additions & 0 deletions packages/typegpu/src/smol/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ function generateStatement(
}

if ('r' in statement) {
// check if the thing at the top of the call stack is a struct
// if so wrap the value returned in a constructor of the struct (its resolved name)
if (
isWgslStruct(ctx.callStack[ctx.callStack.length - 1]) &&
statement.r !== null
) {
return `${ctx.pre}return ${ctx.resolve(ctx.callStack[ctx.callStack.length - 1])}(${resolveRes(ctx, generateExpression(ctx, statement.r))});`;
}

return statement.r === null
? `${ctx.pre}return;`
: `${ctx.pre}return ${resolveRes(ctx, generateExpression(ctx, statement.r))};`;
Expand Down
4 changes: 2 additions & 2 deletions packages/typegpu/tests/rawFn.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ describe('tgpu.fn with raw string WGSL implementation', () => {
template: `
fn vs() {
out.highlighted = highlighted.index;

let h = highlighted;
let x = a.b.c.highlighted.d;
}
Expand Down Expand Up @@ -285,7 +285,7 @@ struct fragment_Output {
const func = tgpu['~unstable']
.fn([d.vec4f, Point], d.vec2f)
.does(/* wgsl */ `(
a: vec4f,
a: vec4f,
b : PointStruct ,
) -> vec2f {
var newPoint: PointStruct;
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/tests/renderPipeline.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ describe('Inter-Stage Variables', () => {
.does('() { layout.bound.alpha; }')
.$uses({ layout });

const fragmentFn = utgpu.vertexFn({}, { out: d.vec4f }).does('() {}');
const fragmentFn = utgpu.fragmentFn({}, { out: d.vec4f }).does('() {}');

const pipeline = root
.withVertex(vertexFn, {})
Expand Down
52 changes: 51 additions & 1 deletion packages/typegpu/tests/tgslFn.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { parse } from 'tgpu-wgsl-parser';
import { describe, expect, it } from 'vitest';
import tgpu from '../src';
import { f32, struct, vec3f } from '../src/data';
import { builtin } from '../src/builtin';
import { f32, struct, vec2f, vec3f, vec4f } from '../src/data';
import { parseResolved } from './utils/parseResolved';

describe('TGSL tgpu.fn function', () => {
Expand Down Expand Up @@ -162,4 +163,53 @@ describe('TGSL tgpu.fn function', () => {

expect(actual).toEqual(expected);
});

it('resolves vertexFn', () => {
const vertexFn = tgpu['~unstable']
.vertexFn(
{
vi: builtin.vertexIndex,
ii: builtin.instanceIndex,
color: vec4f,
},
{
pos: builtin.position,
uv: vec2f,
},
)
.does((input) => {
const vi = input.vi;
const ii = input.ii;
const color = input.color;

return {
pos: vec4f(color.w, ii, vi, 1),
uv: vec2f(color.w, vi),
};
})
.$name('vertex_fn');

const actual = parseResolved({ vertexFn });

const expected = parse(`
struct vertex_fn_Output {
@builtin(position) pos: vec4f,
@location(0) uv: vec2f,
}
struct VertexInput {
@builtin(vertex_index) vi: u32,
@builtin(instance_index) ii: u32,
@location(0) color: vec4f,
}

@vertex fn vertex_fn(input: VertexInput) -> vertex_fn_Output{
var vi = input.vi;
var ii = input.ii;
var color = input.color;
return vertex_fn_Output(vec4f(color.w, ii, vi, 1), vec2f(color.w, vi));
}
`);

expect(actual).toEqual(expected);
});
});
Loading