Skip to content

Commit 7b88f68

Browse files
committed
Add slang shaders for gltf ray tracing sample
1 parent d6a1f6a commit 7b88f68

4 files changed

Lines changed: 263 additions & 4 deletions

File tree

shaders/slang/_rename.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ def checkRenameFiles(samplename):
1818
"raytracingbasic.rmiss.spv": "miss.rmiss.spv",
1919
"raytracingbasic.rgen.spv": "raygen.rgen.spv",
2020
}
21+
case "raytracinggltf":
22+
mappings = {
23+
"raytracinggltf.rchit.spv": "closesthit.rchit.spv",
24+
"raytracinggltf.rmiss.spv": "miss.rmiss.spv",
25+
"raytracinggltf.rgen.spv": "raygen.rgen.spv",
26+
"raytracinggltf.rahit.spv": "anyhit.rahit.spv",
27+
}
2128
case "raytracingreflections":
2229
mappings = {
2330
"raytracingreflections.rchit.spv": "closesthit.rchit.spv",

shaders/slang/compileshaders.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def getShaderStages(filename):
4747
stages.append("closesthit")
4848
if '[shader("callable")]' in filecontent:
4949
stages.append("callable")
50+
if '[shader("intersection")]' in filecontent:
51+
stages.append("intersection")
52+
if '[shader("anyhit")]' in filecontent:
53+
stages.append("anyhit")
5054
if '[shader("compute")]' in filecontent:
5155
stages.append("compute")
5256
if '[shader("amplification")]' in filecontent:
@@ -55,6 +59,10 @@ def getShaderStages(filename):
5559
stages.append("mesh")
5660
if '[shader("geometry")]' in filecontent:
5761
stages.append("geometry")
62+
if '[shader("hull")]' in filecontent:
63+
stages.append("hull")
64+
if '[shader("domain")]' in filecontent:
65+
stages.append("domain")
5866
f.close()
5967
return stages
6068

@@ -83,10 +91,7 @@ def getShaderStages(filename):
8391
print("Compiling %s" % input_file)
8492
output_base_file_name = input_file
8593
for stage in stages:
86-
if (len(stages) > 0):
87-
entry_point = stage + "Main"
88-
else:
89-
entry_point = "main"
94+
entry_point = stage + "Main"
9095
output_ext = ""
9196
match stage:
9297
case "vertex":
@@ -101,6 +106,10 @@ def getShaderStages(filename):
101106
output_ext = ".rchit"
102107
case "callable":
103108
output_ext = ".rcall"
109+
case "intersection":
110+
output_ext = ".rint"
111+
case "anyhit":
112+
output_ext = ".rahit"
104113
case "compute":
105114
output_ext = ".comp"
106115
case "mesh":
@@ -109,9 +118,15 @@ def getShaderStages(filename):
109118
output_ext = ".task"
110119
case "geometry":
111120
output_ext = ".geom"
121+
case "hull":
122+
output_ext = ".tesc"
123+
case "domain":
124+
output_ext = ".tese"
112125
output_file = output_base_file_name + output_ext + ".spv"
113126
output_file = output_file.replace(".slang", "")
127+
print(output_file)
114128
res = subprocess.call("%s %s -profile spirv_1_4 -matrix-layout-column-major -target spirv -o %s -entry %s -stage %s -warnings-disable 39001" % (compiler_path, input_file, output_file, entry_point, stage), shell=True)
115129
if res != 0:
130+
print("Error %s", res)
116131
sys.exit(res)
117132
checkRenameFiles(folder_name)
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
/* Copyright (c) 2025, Sascha Willems
2+
*
3+
* SPDX-License-Identifier: MIT
4+
*
5+
*/
6+
7+
struct Payload
8+
{
9+
float3 hitValue;
10+
uint payloadSeed;
11+
bool shadowed;
12+
};
13+
14+
struct GeometryNode {
15+
ConstBufferPointer<float4> vertices;
16+
ConstBufferPointer<uint> indices;
17+
int textureIndexBaseColor;
18+
int textureIndexOcclusion;
19+
};
20+
21+
struct UBOCameraProperties {
22+
float4x4 viewInverse;
23+
float4x4 projInverse;
24+
uint frame;
25+
}
26+
27+
[[vk::binding(0, 0)]] RaytracingAccelerationStructure accelStruct;
28+
[[vk::binding(1, 0)]] RWTexture2D<float4> image;
29+
[[vk::binding(2, 0)]] ConstantBuffer<UBOCameraProperties> cam;
30+
[[vk::binding(4, 0)]] StructuredBuffer<GeometryNode> geometryNodes;
31+
[[vk::binding(5, 0)]] Sampler2D textures[];
32+
33+
struct Vertex
34+
{
35+
float3 pos;
36+
float3 normal;
37+
float2 uv;
38+
};
39+
40+
struct Triangle {
41+
Vertex vertices[3];
42+
float3 normal;
43+
float2 uv;
44+
};
45+
46+
struct Attributes
47+
{
48+
float2 bary;
49+
};
50+
51+
// Tiny Encryption Algorithm
52+
// By Fahad Zafar, Marc Olano and Aaron Curtis, see https://www.highperformancegraphics.org/previous/www_2010/media/GPUAlgorithms/HPG2010_GPUAlgorithms_Zafar.pdf
53+
uint tea(uint val0, uint val1)
54+
{
55+
uint sum = 0;
56+
uint v0 = val0;
57+
uint v1 = val1;
58+
for (uint n = 0; n < 16; n++)
59+
{
60+
sum += 0x9E3779B9;
61+
v0 += ((v1 << 4) + 0xA341316C) ^ (v1 + sum) ^ ((v1 >> 5) + 0xC8013EA4);
62+
v1 += ((v0 << 4) + 0xAD90777D) ^ (v0 + sum) ^ ((v0 >> 5) + 0x7E95761E);
63+
}
64+
return v0;
65+
}
66+
67+
// Linear congruential generator based on the previous RNG state
68+
// See https://en.wikipedia.org/wiki/Linear_congruential_generator
69+
uint lcg(inout uint previous)
70+
{
71+
const uint multiplier = 1664525u;
72+
const uint increment = 1013904223u;
73+
previous = (multiplier * previous + increment);
74+
return previous & 0x00FFFFFF;
75+
}
76+
77+
// Generate a random float in [0, 1) given the previous RNG state
78+
float rnd(inout uint previous)
79+
{
80+
return (float(lcg(previous)) / float(0x01000000));
81+
}
82+
83+
// This function will unpack our vertex buffer data into a single triangle and calculates uv coordinates
84+
Triangle unpackTriangle(uint index, Attributes attribs) {
85+
Triangle tri;
86+
const uint triIndex = index * 3;
87+
const uint vertexsize = 112;
88+
89+
GeometryNode geometryNode = geometryNodes[GeometryIndex()];
90+
91+
// Indices indices = Indices(geometryNode.indexBufferDeviceAddress);
92+
// Vertices vertices = Vertices(geometryNode.vertexBufferDeviceAddress);
93+
94+
// Unpack vertices
95+
// Data is packed as float4 so we can map to the glTF vertex structure from the host side
96+
// We match vkglTF::Vertex: pos.xyz+normal.x, normalyz+uv.xy
97+
// glm::float3 pos;
98+
// glm::float3 normal;
99+
// glm::float2 uv;
100+
// ...
101+
for (uint i = 0; i < 3; i++) {
102+
const uint offset = geometryNode.indices[triIndex + i] * 6;
103+
float4 d0 = geometryNode.vertices[offset + 0]; // pos.xyz, n.x
104+
float4 d1 = geometryNode.vertices[offset + 1]; // n.yz, uv.xy
105+
tri.vertices[i].pos = d0.xyz;
106+
tri.vertices[i].normal = float3(d0.w, d1.xy);
107+
tri.vertices[i].uv = float2(d1.z, d1.w);
108+
}
109+
// Calculate values at barycentric coordinates
110+
float3 barycentricCoords = float3(1.0f - attribs.bary.x - attribs.bary.y, attribs.bary.x, attribs.bary.y);
111+
tri.uv = tri.vertices[0].uv * barycentricCoords.x + tri.vertices[1].uv * barycentricCoords.y + tri.vertices[2].uv * barycentricCoords.z;
112+
tri.normal = tri.vertices[0].normal * barycentricCoords.x + tri.vertices[1].normal * barycentricCoords.y + tri.vertices[2].normal * barycentricCoords.z;
113+
return tri;
114+
}
115+
116+
[shader("raygeneration")]
117+
void raygenerationMain()
118+
{
119+
uint3 LaunchID = DispatchRaysIndex();
120+
uint3 LaunchSize = DispatchRaysDimensions();
121+
122+
uint seed = tea(LaunchID.y * LaunchSize.x + LaunchID.x, cam.frame);
123+
124+
float r1 = rnd(seed);
125+
float r2 = rnd(seed);
126+
127+
// Subpixel jitter: send the ray through a different position inside the pixel
128+
// each time, to provide antialiasing.
129+
float2 subpixel_jitter = cam.frame == 0 ? float2(0.5f, 0.5f) : float2(r1, r2);
130+
const float2 pixelCenter = float2(LaunchID.xy) + subpixel_jitter;
131+
const float2 inUV = pixelCenter / float2(LaunchSize.xy);
132+
float2 d = inUV * 2.0 - 1.0;
133+
134+
float4 target = mul(cam.projInverse, float4(d.x, d.y, 1, 1));
135+
136+
RayDesc rayDesc;
137+
rayDesc.Origin = mul(cam.viewInverse, float4(0, 0, 0, 1)).xyz;
138+
rayDesc.Direction = mul(cam.viewInverse, float4(normalize(target.xyz), 0)).xyz;
139+
rayDesc.TMin = 0.001;
140+
rayDesc.TMax = 10000.0;
141+
142+
Payload payload;
143+
payload.hitValue = float3(0.0);
144+
float3 hitValues = float3(0);
145+
146+
const int samples = 4;
147+
148+
// Trace multiple rays for e.g. transparency
149+
for (int smpl = 0; smpl < samples; smpl++) {
150+
payload.payloadSeed = tea(LaunchID.y * LaunchSize.x + LaunchID.x, cam.frame);
151+
TraceRay(accelStruct, RAY_FLAG_NONE, 0xff, 0, 0, 0, rayDesc, payload);
152+
hitValues += payload.hitValue;
153+
}
154+
155+
float3 hitVal = hitValues / float(samples);
156+
157+
if (cam.frame > 0)
158+
{
159+
float a = 1.0f / float(cam.frame + 1);
160+
float3 old_color = image[int2(LaunchID.xy)].xyz;
161+
image[int2(LaunchID.xy)] = float4(lerp(old_color, hitVal, a), 1.0f);
162+
}
163+
else
164+
{
165+
// First frame, replace the value in the buffer
166+
image[int2(LaunchID.xy)] = float4(hitVal, 1.0f);
167+
}
168+
}
169+
170+
[shader("closesthit")]
171+
void closesthitMain(inout Payload payload, in Attributes attribs)
172+
{
173+
Triangle tri = unpackTriangle(PrimitiveIndex(), attribs);
174+
payload.hitValue = float3(tri.normal);
175+
176+
GeometryNode geometryNode = geometryNodes[GeometryIndex()];
177+
178+
float3 color = textures[NonUniformResourceIndex(geometryNode.textureIndexBaseColor)].SampleLevel(tri.uv, 0.0).rgb;
179+
if (geometryNode.textureIndexOcclusion > -1) {
180+
float occlusion = textures[NonUniformResourceIndex(geometryNode.textureIndexOcclusion)].SampleLevel(tri.uv, 0.0).r;
181+
color *= occlusion;
182+
}
183+
184+
payload.hitValue = color;
185+
186+
// Shadow casting
187+
float tmin = 0.001;
188+
float tmax = 10000.0;
189+
float epsilon = 0.001;
190+
float3 origin = WorldRayOrigin() + WorldRayDirection() * RayTCurrent() + tri.normal * epsilon;
191+
payload.shadowed = true;
192+
float3 lightVector = float3(-5.0, -2.5, -5.0);
193+
// Trace shadow ray and offset indices to match shadow hit/miss shader group indices
194+
// traceRayEXT(topLevelAS, gl_RayFlagsTerminateOnFirstHitEXT | gl_RayFlagsOpaqueEXT | gl_RayFlagsSkipClosestHitShaderEXT, 0xFF, 0, 0, 1, origin, tmin, lightVector, tmax, 2);
195+
// if (shadowed) {
196+
// hitValue *= 0.7;
197+
// }
198+
}
199+
200+
[shader("anyhit")]
201+
void anyhitMain(inout Payload payload, in Attributes attribs)
202+
{
203+
Triangle tri = unpackTriangle(PrimitiveIndex(), attribs);
204+
GeometryNode geometryNode = geometryNodes[GeometryIndex()];
205+
float4 color = textures[NonUniformResourceIndex(geometryNode.textureIndexBaseColor)].SampleLevel(tri.uv, 0.0);
206+
// If the alpha value of the texture at the current UV coordinates is below a given threshold, we'll ignore this intersection
207+
// That way ray traversal will be stopped and the miss shader will be invoked
208+
if (color.a < 0.9) {
209+
if (rnd(payload.payloadSeed) > color.a) {
210+
IgnoreHit();
211+
}
212+
}
213+
}
214+
215+
[shader("miss")]
216+
void missMain(inout Payload payload)
217+
{
218+
payload.hitValue = float3(1.0);
219+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/* Copyright (c) 2025, Sascha Willems
2+
*
3+
* SPDX-License-Identifier: MIT
4+
*
5+
*/
6+
7+
struct Payload
8+
{
9+
float3 hitValue;
10+
uint payloadSeed;
11+
bool shadowed;
12+
};
13+
14+
[shader("miss")]
15+
void missMain(inout Payload payload)
16+
{
17+
payload.shadowed = false;
18+
}

0 commit comments

Comments
 (0)