1+ /* Copyright (c) 2025, Sascha Willems
2+ *
3+ * SPDX-License-Identifier: MIT
4+ *
5+ */
6+
7+ struct Vertex
8+ {
9+ float3 pos;
10+ float2 uv;
11+ };
12+
13+ struct Triangle {
14+ Vertex vertices [3 ];
15+ float2 uv;
16+ };
17+
18+ struct BufferReferences {
19+ // Pointer to the buffer with the scene's MVP matrix
20+ ConstBufferPointer< float4 > vertices;
21+ // Pointer to the buffer for the data for each model
22+ ConstBufferPointer< uint > indices;
23+ };
24+ [[vk::push_constant]] BufferReferences bufferReferences;
25+
26+ struct Payload
27+ {
28+ float3 hitValue;
29+ };
30+
31+ struct UBOCameraProperties {
32+ float4x4 viewInverse;
33+ float4x4 projInverse;
34+ }
35+
36+ RaytracingAccelerationStructure accelStruct;
37+ RWTexture2D < float4 > image;
38+ ConstantBuffer < UBOCameraProperties> cam;
39+ Sampler2D samplerColor;
40+
41+ struct Attributes
42+ {
43+ float2 bary;
44+ };
45+
46+ Triangle unpackTriangle(uint index, Attributes attribs) {
47+ Triangle tri;
48+ const uint triIndex = index * 3 ;
49+ const uint vertexSize = 32 ;
50+
51+ // Unpack vertices
52+ // Data is packed as float4 so we can map to the glTF vertex structure from the host side
53+ for (uint i = 0 ; i < 3 ; i++ ) {
54+ const uint offset = bufferReferences .indices [triIndex + i] * (vertexSize / 16 );
55+ float4 d0 = bufferReferences .vertices [offset + 0 ]; // pos.xyz, n.x
56+ float4 d1 = bufferReferences .vertices [offset + 1 ]; // n.yz, uv.xy
57+ tri .vertices [i].pos = d0 .xyz ;
58+ tri .vertices [i].uv = d1 .zw ;
59+ }
60+ // Calculate values at barycentric coordinates
61+ float3 barycentricCoords = float3(1 . 0 f - attribs .bary .x - attribs .bary .y , attribs .bary .x , attribs .bary .y );
62+ tri .uv = tri .vertices [0 ].uv * barycentricCoords .x + tri .vertices [1 ].uv * barycentricCoords .y + tri .vertices [2 ].uv * barycentricCoords .z ;
63+ return tri;
64+ }
65+
66+ [shader(" raygeneration" )]
67+ void raygenerationMain()
68+ {
69+ uint3 LaunchID = DispatchRaysIndex();
70+ uint3 LaunchSize = DispatchRaysDimensions();
71+
72+ const float2 pixelCenter = float2(LaunchID .xy ) + float2(0 . 5 , 0 . 5 );
73+ const float2 inUV = pixelCenter / float2(LaunchSize .xy );
74+ float2 d = inUV * 2 . 0 - 1 . 0 ;
75+ float4 target = mul(cam .projInverse , float4(d .x , d .y , 1 , 1 ));
76+
77+ RayDesc rayDesc;
78+ rayDesc .Origin = mul(cam .viewInverse , float4(0 , 0 , 0 , 1 )).xyz ;
79+ rayDesc .Direction = mul(cam .viewInverse , float4(normalize(target .xyz ), 0 )).xyz ;
80+ rayDesc .TMin = 0 . 001 ;
81+ rayDesc .TMax = 10000 . 0 ;
82+
83+ Payload payload;
84+ TraceRay(accelStruct, RAY_FLAG_NONE, 0x ff , 0 , 0 , 0 , rayDesc, payload);
85+
86+ image [int2(LaunchID .xy )] = float4(payload .hitValue , 0 . 0 );
87+ }
88+
89+ [shader(" closesthit" )]
90+ void closesthitMain(inout Payload payload, in Attributes attribs)
91+ {
92+ Triangle tri = unpackTriangle(PrimitiveIndex(), attribs);
93+ // Fetch the color for this ray hit from the texture at the current uv coordinates
94+ float4 color = samplerColor .SampleLevel (tri .uv , 0 . 0 );
95+ payload .hitValue = color .rgb ;
96+ }
97+
98+ [shader(" anyhit" )]
99+ void anyhitMain(inout Payload payload, in Attributes attribs)
100+ {
101+ Triangle tri = unpackTriangle(PrimitiveIndex(), attribs);
102+ float4 color = samplerColor .SampleLevel (tri .uv , 0 . 0 );
103+ // If the alpha value of the texture at the current UV coordinates is below a given threshold, we'll ignore this intersection
104+ // That way ray traversal will be stopped and the miss shader will be invoked
105+ if (color .a < 0 . 9 ) {
106+ IgnoreHit();
107+ }
108+ }
109+
110+ [shader(" miss" )]
111+ void missMain(inout Payload payload)
112+ {
113+ payload .hitValue = float3(0 . 0 , 0 . 0 , 0 . 2 );
114+ }
0 commit comments