Skip to content

Commit e4b7bb1

Browse files
committed
Add slang shaders for compute cloth sample
1 parent c58fd18 commit e4b7bb1

2 files changed

Lines changed: 245 additions & 0 deletions

File tree

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/* Copyright (c) 2025, Sascha Willems
2+
*
3+
* SPDX-License-Identifier: MIT
4+
*
5+
*/
6+
7+
struct VSInput
8+
{
9+
float3 Pos;
10+
float2 UV;
11+
float3 Normal;
12+
};
13+
14+
struct VSOutput
15+
{
16+
float4 Pos : SV_POSITION;
17+
float2 UV;
18+
float3 Normal;
19+
float3 ViewVec;
20+
float3 LightVec;
21+
};
22+
23+
struct UBO
24+
{
25+
float4x4 projection;
26+
float4x4 modelview;
27+
float4 lightPos;
28+
};
29+
[[vk::binding(0,0)]] ConstantBuffer<UBO> ubo;
30+
[[vk::binding(1,0)]] Sampler2D samplerColor;
31+
32+
struct Particle {
33+
float4 pos;
34+
float4 vel;
35+
float4 uv;
36+
float4 normal;
37+
};
38+
39+
[[vk::binding(0,0)]] StructuredBuffer<Particle> particleIn;
40+
[[vk::binding(1,0)]] RWStructuredBuffer<Particle> particleOut;
41+
42+
struct UBOCompute
43+
{
44+
float deltaT;
45+
float particleMass;
46+
float springStiffness;
47+
float damping;
48+
float restDistH;
49+
float restDistV;
50+
float restDistD;
51+
float sphereRadius;
52+
float4 spherePos;
53+
float4 gravity;
54+
int2 particleCount;
55+
};
56+
[[vk::binding(2, 0)]] ConstantBuffer<UBOCompute> params;
57+
58+
float3 springForce(float3 p0, float3 p1, float restDist)
59+
{
60+
float3 dist = p0 - p1;
61+
return normalize(dist) * params.springStiffness * (length(dist) - restDist);
62+
}
63+
64+
[shader("vertex")]
65+
VSOutput vertexMain(VSInput input)
66+
{
67+
VSOutput output;
68+
output.UV = input.UV;
69+
output.Normal = input.Normal.xyz;
70+
float4 eyePos = mul(ubo.modelview, float4(input.Pos.x, input.Pos.y, input.Pos.z, 1.0));
71+
output.Pos = mul(ubo.projection, eyePos);
72+
float4 pos = float4(input.Pos, 1.0);
73+
float3 lPos = ubo.lightPos.xyz;
74+
output.LightVec = lPos - pos.xyz;
75+
output.ViewVec = -pos.xyz;
76+
return output;
77+
}
78+
79+
[shader("fragment")]
80+
float4 fragmentMain(VSOutput input)
81+
{
82+
float3 color = samplerColor.Sample(input.UV).rgb;
83+
float3 N = normalize(input.Normal);
84+
float3 L = normalize(input.LightVec);
85+
float3 V = normalize(input.ViewVec);
86+
float3 R = reflect(-L, N);
87+
float3 diffuse = max(dot(N, L), 0.15) * float3(1, 1, 1);
88+
float3 specular = pow(max(dot(R, V), 0.0), 8.0) * float3(0.2, 0.2, 0.2);
89+
return float4(diffuse * color.rgb + specular, 1.0);
90+
}
91+
92+
[shader("compute")]
93+
[numthreads(10, 10, 1)]
94+
void computeMain(uint3 id: SV_DispatchThreadID, uniform uint calculateNormals)
95+
{
96+
uint index = id.y * params.particleCount.x + id.x;
97+
if (index > params.particleCount.x * params.particleCount.y)
98+
return;
99+
100+
// Initial force from gravity
101+
float3 force = params.gravity.xyz * params.particleMass;
102+
103+
float3 pos = particleIn[index].pos.xyz;
104+
float3 vel = particleIn[index].vel.xyz;
105+
106+
// Spring forces from neighboring particles
107+
// left
108+
if (id.x > 0) {
109+
force += springForce(particleIn[index-1].pos.xyz, pos, params.restDistH);
110+
}
111+
// right
112+
if (id.x < params.particleCount.x - 1) {
113+
force += springForce(particleIn[index + 1].pos.xyz, pos, params.restDistH);
114+
}
115+
// upper
116+
if (id.y < params.particleCount.y - 1) {
117+
force += springForce(particleIn[index + params.particleCount.x].pos.xyz, pos, params.restDistV);
118+
}
119+
// lower
120+
if (id.y > 0) {
121+
force += springForce(particleIn[index - params.particleCount.x].pos.xyz, pos, params.restDistV);
122+
}
123+
// upper-left
124+
if ((id.x > 0) && (id.y < params.particleCount.y - 1)) {
125+
force += springForce(particleIn[index + params.particleCount.x - 1].pos.xyz, pos, params.restDistD);
126+
}
127+
// lower-left
128+
if ((id.x > 0) && (id.y > 0)) {
129+
force += springForce(particleIn[index - params.particleCount.x - 1].pos.xyz, pos, params.restDistD);
130+
}
131+
// upper-right
132+
if ((id.x < params.particleCount.x - 1) && (id.y < params.particleCount.y - 1)) {
133+
force += springForce(particleIn[index + params.particleCount.x + 1].pos.xyz, pos, params.restDistD);
134+
}
135+
// lower-right
136+
if ((id.x < params.particleCount.x - 1) && (id.y > 0)) {
137+
force += springForce(particleIn[index - params.particleCount.x + 1].pos.xyz, pos, params.restDistD);
138+
}
139+
140+
force += (-params.damping * vel);
141+
142+
// Integrate
143+
float3 f = force * (1.0 / params.particleMass);
144+
particleOut[index].pos = float4(pos + vel * params.deltaT + 0.5 * f * params.deltaT * params.deltaT, 1.0);
145+
particleOut[index].vel = float4(vel + f * params.deltaT, 0.0);
146+
147+
// Sphere collision
148+
float3 sphereDist = particleOut[index].pos.xyz - params.spherePos.xyz;
149+
if (length(sphereDist) < params.sphereRadius + 0.01) {
150+
// If the particle is inside the sphere, push it to the outer radius
151+
particleOut[index].pos.xyz = params.spherePos.xyz + normalize(sphereDist) * (params.sphereRadius + 0.01);
152+
// Cancel out velocity
153+
particleOut[index].vel = float4(0, 0, 0, 0);
154+
}
155+
156+
// Normals
157+
if (calculateNormals == 1) {
158+
float3 normal = float3(0, 0, 0);
159+
float3 a, b, c;
160+
if (id.y > 0) {
161+
if (id.x > 0) {
162+
a = particleIn[index - 1].pos.xyz - pos;
163+
b = particleIn[index - params.particleCount.x - 1].pos.xyz - pos;
164+
c = particleIn[index - params.particleCount.x].pos.xyz - pos;
165+
normal += cross(a,b) + cross(b,c);
166+
}
167+
if (id.x < params.particleCount.x - 1) {
168+
a = particleIn[index - params.particleCount.x].pos.xyz - pos;
169+
b = particleIn[index - params.particleCount.x + 1].pos.xyz - pos;
170+
c = particleIn[index + 1].pos.xyz - pos;
171+
normal += cross(a,b) + cross(b,c);
172+
}
173+
}
174+
if (id.y < params.particleCount.y - 1) {
175+
if (id.x > 0) {
176+
a = particleIn[index + params.particleCount.x].pos.xyz - pos;
177+
b = particleIn[index + params.particleCount.x - 1].pos.xyz - pos;
178+
c = particleIn[index - 1].pos.xyz - pos;
179+
normal += cross(a,b) + cross(b,c);
180+
}
181+
if (id.x < params.particleCount.x - 1) {
182+
a = particleIn[index + 1].pos.xyz - pos;
183+
b = particleIn[index + params.particleCount.x + 1].pos.xyz - pos;
184+
c = particleIn[index + params.particleCount.x].pos.xyz - pos;
185+
normal += cross(a,b) + cross(b,c);
186+
}
187+
}
188+
particleOut[index].normal = float4(normalize(normal), 0.0f);
189+
}
190+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2025, Sascha Willems
2+
*
3+
* SPDX-License-Identifier: MIT
4+
*
5+
*/
6+
7+
struct VSInput
8+
{
9+
float3 Pos;
10+
float2 UV;
11+
float3 Normal;
12+
};
13+
14+
struct VSOutput
15+
{
16+
float4 Pos : SV_POSITION;
17+
float3 Normal;
18+
float3 ViewVec;
19+
float3 LightVec;
20+
};
21+
22+
struct UBO
23+
{
24+
float4x4 projection;
25+
float4x4 modelview;
26+
float4 lightPos;
27+
};
28+
ConstantBuffer<UBO> ubo;
29+
30+
[shader("vertex")]
31+
VSOutput vertexMain(VSInput input)
32+
{
33+
VSOutput output;
34+
float4 eyePos = mul(ubo.modelview, float4(input.Pos.x, input.Pos.y, input.Pos.z, 1.0));
35+
output.Pos = mul(ubo.projection, eyePos);
36+
float4 pos = float4(input.Pos, 1.0);
37+
float3 lPos = ubo.lightPos.xyz;
38+
output.LightVec = lPos - pos.xyz;
39+
output.ViewVec = -pos.xyz;
40+
output.Normal = input.Normal;
41+
return output;
42+
}
43+
44+
[shader("fragment")]
45+
float4 fragmentMain(VSOutput input)
46+
{
47+
float3 color = float3(0.5, 0.5, 0.5);
48+
float3 N = normalize(input.Normal);
49+
float3 L = normalize(input.LightVec);
50+
float3 V = normalize(input.ViewVec);
51+
float3 R = reflect(-L, N);
52+
float3 diffuse = max(dot(N, L), 0.15);
53+
float3 specular = pow(max(dot(R, V), 0.0), 32.0);
54+
return float4(diffuse * color.rgb + specular, 1.0);
55+
}

0 commit comments

Comments
 (0)