1+ /* Copyright (c) 2025, Sascha Willems
2+ *
3+ * SPDX-License-Identifier: MIT
4+ *
5+ */
6+
7+ struct VSInput
8+ {
9+ float3 Pos;
10+ float2 UV;
11+ float3 Normal;
12+ };
13+
14+ struct VSOutput
15+ {
16+ float4 Pos : SV_POSITION;
17+ float2 UV;
18+ float3 Normal;
19+ float3 ViewVec;
20+ float3 LightVec;
21+ };
22+
23+ struct UBO
24+ {
25+ float4x4 projection;
26+ float4x4 modelview;
27+ float4 lightPos;
28+ };
29+ [[vk::binding(0 ,0 )]] ConstantBuffer < UBO> ubo;
30+ [[vk::binding(1 ,0 )]] Sampler2D samplerColor;
31+
32+ struct Particle {
33+ float4 pos;
34+ float4 vel;
35+ float4 uv;
36+ float4 normal;
37+ };
38+
39+ [[vk::binding(0 ,0 )]] StructuredBuffer < Particle> particleIn;
40+ [[vk::binding(1 ,0 )]] RWStructuredBuffer < Particle> particleOut;
41+
42+ struct UBOCompute
43+ {
44+ float deltaT;
45+ float particleMass;
46+ float springStiffness;
47+ float damping;
48+ float restDistH;
49+ float restDistV;
50+ float restDistD;
51+ float sphereRadius;
52+ float4 spherePos;
53+ float4 gravity;
54+ int2 particleCount;
55+ };
56+ [[vk::binding(2 , 0 )]] ConstantBuffer < UBOCompute> params;
57+
58+ float3 springForce(float3 p0, float3 p1, float restDist)
59+ {
60+ float3 dist = p0 - p1;
61+ return normalize(dist) * params .springStiffness * (length(dist) - restDist);
62+ }
63+
64+ [shader(" vertex" )]
65+ VSOutput vertexMain(VSInput input)
66+ {
67+ VSOutput output;
68+ output .UV = input .UV ;
69+ output .Normal = input .Normal .xyz ;
70+ float4 eyePos = mul(ubo .modelview , float4(input .Pos .x , input .Pos .y , input .Pos .z , 1 . 0 ));
71+ output .Pos = mul(ubo .projection , eyePos);
72+ float4 pos = float4(input .Pos , 1 . 0 );
73+ float3 lPos = ubo .lightPos .xyz ;
74+ output .LightVec = lPos - pos .xyz ;
75+ output .ViewVec = - pos .xyz ;
76+ return output;
77+ }
78+
79+ [shader(" fragment" )]
80+ float4 fragmentMain(VSOutput input)
81+ {
82+ float3 color = samplerColor .Sample (input .UV ).rgb ;
83+ float3 N = normalize(input .Normal );
84+ float3 L = normalize(input .LightVec );
85+ float3 V = normalize(input .ViewVec );
86+ float3 R = reflect(- L, N);
87+ float3 diffuse = max(dot(N, L), 0 . 15 ) * float3(1 , 1 , 1 );
88+ float3 specular = pow(max(dot(R, V), 0 . 0 ), 8 . 0 ) * float3(0 . 2 , 0 . 2 , 0 . 2 );
89+ return float4(diffuse * color .rgb + specular, 1 . 0 );
90+ }
91+
92+ [shader(" compute" )]
93+ [numthreads(10 , 10 , 1 )]
94+ void computeMain(uint3 id: SV_DispatchThreadID, uniform uint calculateNormals)
95+ {
96+ uint index = id .y * params .particleCount .x + id .x ;
97+ if (index > params .particleCount .x * params .particleCount .y )
98+ return ;
99+
100+ // Initial force from gravity
101+ float3 force = params .gravity .xyz * params .particleMass ;
102+
103+ float3 pos = particleIn [index].pos .xyz ;
104+ float3 vel = particleIn [index].vel .xyz ;
105+
106+ // Spring forces from neighboring particles
107+ // left
108+ if (id .x > 0 ) {
109+ force += springForce(particleIn [index- 1 ].pos .xyz , pos, params .restDistH );
110+ }
111+ // right
112+ if (id .x < params .particleCount .x - 1 ) {
113+ force += springForce(particleIn [index + 1 ].pos .xyz , pos, params .restDistH );
114+ }
115+ // upper
116+ if (id .y < params .particleCount .y - 1 ) {
117+ force += springForce(particleIn [index + params .particleCount .x ].pos .xyz , pos, params .restDistV );
118+ }
119+ // lower
120+ if (id .y > 0 ) {
121+ force += springForce(particleIn [index - params .particleCount .x ].pos .xyz , pos, params .restDistV );
122+ }
123+ // upper-left
124+ if ((id .x > 0 ) && (id .y < params .particleCount .y - 1 )) {
125+ force += springForce(particleIn [index + params .particleCount .x - 1 ].pos .xyz , pos, params .restDistD );
126+ }
127+ // lower-left
128+ if ((id .x > 0 ) && (id .y > 0 )) {
129+ force += springForce(particleIn [index - params .particleCount .x - 1 ].pos .xyz , pos, params .restDistD );
130+ }
131+ // upper-right
132+ if ((id .x < params .particleCount .x - 1 ) && (id .y < params .particleCount .y - 1 )) {
133+ force += springForce(particleIn [index + params .particleCount .x + 1 ].pos .xyz , pos, params .restDistD );
134+ }
135+ // lower-right
136+ if ((id .x < params .particleCount .x - 1 ) && (id .y > 0 )) {
137+ force += springForce(particleIn [index - params .particleCount .x + 1 ].pos .xyz , pos, params .restDistD );
138+ }
139+
140+ force += (- params .damping * vel);
141+
142+ // Integrate
143+ float3 f = force * (1 . 0 / params .particleMass );
144+ particleOut [index].pos = float4(pos + vel * params .deltaT + 0 . 5 * f * params .deltaT * params .deltaT , 1 . 0 );
145+ particleOut [index].vel = float4(vel + f * params .deltaT , 0 . 0 );
146+
147+ // Sphere collision
148+ float3 sphereDist = particleOut [index].pos .xyz - params .spherePos .xyz ;
149+ if (length(sphereDist) < params .sphereRadius + 0 . 01 ) {
150+ // If the particle is inside the sphere, push it to the outer radius
151+ particleOut [index].pos .xyz = params .spherePos .xyz + normalize(sphereDist) * (params .sphereRadius + 0 . 01 );
152+ // Cancel out velocity
153+ particleOut [index].vel = float4(0 , 0 , 0 , 0 );
154+ }
155+
156+ // Normals
157+ if (calculateNormals == 1 ) {
158+ float3 normal = float3(0 , 0 , 0 );
159+ float3 a, b, c;
160+ if (id .y > 0 ) {
161+ if (id .x > 0 ) {
162+ a = particleIn [index - 1 ].pos .xyz - pos;
163+ b = particleIn [index - params .particleCount .x - 1 ].pos .xyz - pos;
164+ c = particleIn [index - params .particleCount .x ].pos .xyz - pos;
165+ normal += cross(a,b) + cross(b,c);
166+ }
167+ if (id .x < params .particleCount .x - 1 ) {
168+ a = particleIn [index - params .particleCount .x ].pos .xyz - pos;
169+ b = particleIn [index - params .particleCount .x + 1 ].pos .xyz - pos;
170+ c = particleIn [index + 1 ].pos .xyz - pos;
171+ normal += cross(a,b) + cross(b,c);
172+ }
173+ }
174+ if (id .y < params .particleCount .y - 1 ) {
175+ if (id .x > 0 ) {
176+ a = particleIn [index + params .particleCount .x ].pos .xyz - pos;
177+ b = particleIn [index + params .particleCount .x - 1 ].pos .xyz - pos;
178+ c = particleIn [index - 1 ].pos .xyz - pos;
179+ normal += cross(a,b) + cross(b,c);
180+ }
181+ if (id .x < params .particleCount .x - 1 ) {
182+ a = particleIn [index + 1 ].pos .xyz - pos;
183+ b = particleIn [index + params .particleCount .x + 1 ].pos .xyz - pos;
184+ c = particleIn [index + params .particleCount .x ].pos .xyz - pos;
185+ normal += cross(a,b) + cross(b,c);
186+ }
187+ }
188+ particleOut [index].normal = float4(normalize(normal), 0 . 0 f );
189+ }
190+ }
0 commit comments