1+ /* Copyright (c) 2025, Sascha Willems
2+ *
3+ * SPDX-License-Identifier: MIT
4+ *
5+ */
6+
7+ RaytracingAccelerationStructure accelStruct;
8+ RWTexture2D < float4 > image;
9+ struct CameraProperties
10+ {
11+ float4x4 viewInverse;
12+ float4x4 projInverse;
13+ float4 lightPos;
14+ int vertexSize;
15+ };
16+ ConstantBuffer < CameraProperties> ubo;
17+ StructuredBuffer < float4 > vertices;
18+ StructuredBuffer < uint > indices;
19+
20+ // Max. number of recursion is passed via a specialization constant
21+ [SpecializationConstant] const int MAX_RECURSION = 0 ;
22+
23+ struct Attributes
24+ {
25+ float2 bary;
26+ };
27+
28+ struct RayPayload
29+ {
30+ float3 color;
31+ float distance;
32+ float3 normal;
33+ float reflector;
34+ };
35+
36+ struct Vertex
37+ {
38+ float3 pos;
39+ float3 normal;
40+ float2 uv;
41+ float4 color;
42+ float4 _pad0;
43+ float4 _pad1;
44+ };
45+
46+ Vertex unpack(uint index)
47+ {
48+ // Unpack the vertices from the SSBO using the glTF vertex structure
49+ // The multiplier is the size of the vertex divided by four float components (=16 bytes)
50+ const int m = ubo .vertexSize / 16 ;
51+
52+ float4 d0 = vertices [m * index + 0 ];
53+ float4 d1 = vertices [m * index + 1 ];
54+ float4 d2 = vertices [m * index + 2 ];
55+
56+ Vertex v;
57+ v .pos = d0 .xyz ;
58+ v .normal = float3(d0 .w , d1 .x , d1 .y );
59+ v .color = float4(d2 .x , d2 .y , d2 .z , 1 . 0 );
60+
61+ return v;
62+ }
63+
64+ [shader(" raygeneration" )]
65+ void raygenerationMain()
66+ {
67+ uint3 LaunchID = DispatchRaysIndex();
68+ uint3 LaunchSize = DispatchRaysDimensions();
69+
70+ const float2 pixelCenter = float2(LaunchID .xy ) + float2(0 . 5 , 0 . 5 );
71+ const float2 inUV = pixelCenter / float2(LaunchSize .xy );
72+ float2 d = inUV * 2 . 0 - 1 . 0 ;
73+ float4 target = mul(ubo .projInverse , float4(d .x , d .y , 1 , 1 ));
74+
75+ RayDesc rayDesc;
76+ rayDesc .Origin = mul(ubo .viewInverse , float4(0 , 0 , 0 , 1 )).xyz ;
77+ rayDesc .Direction = mul(ubo .viewInverse , float4(normalize(target .xyz ), 0 )).xyz ;
78+ rayDesc .TMin = 0 . 001 ;
79+ rayDesc .TMax = 10000 . 0 ;
80+
81+ float3 color = float3(0 . 0 , 0 . 0 , 0 . 0 );
82+
83+ for (int i = 0 ; i < MAX_RECURSION; i++ ) {
84+ RayPayload rayPayload;
85+ TraceRay(accelStruct, RAY_FLAG_FORCE_OPAQUE, 0x ff , 0 , 0 , 0 , rayDesc, rayPayload);
86+ float3 hitColor = rayPayload .color ;
87+
88+ if (rayPayload .distance < 0 . 0 f ) {
89+ color += hitColor;
90+ break ;
91+ } else if (rayPayload .reflector == 1 . 0 f ) {
92+ const float3 hitPos = rayDesc .Origin + rayDesc .Direction * rayPayload .distance ;
93+ rayDesc .Origin = hitPos + rayPayload .normal * 0 . 001 f ;
94+ rayDesc .Direction = reflect(rayDesc .Direction , rayPayload .normal );
95+ } else {
96+ color += hitColor;
97+ break ;
98+ }
99+ }
100+
101+ image [int2(LaunchID .xy )] = float4(color, 0 . 0 );
102+ }
103+
104+ [shader(" closesthit" )]
105+ void closesthitMain(inout RayPayload rayPayload, in Attributes attribs)
106+ {
107+ uint PrimitiveID = PrimitiveIndex();
108+ int3 index = int3(indices [3 * PrimitiveID], indices [3 * PrimitiveID + 1 ], indices [3 * PrimitiveID + 2 ]);
109+
110+ Vertex v0 = unpack(index .x );
111+ Vertex v1 = unpack(index .y );
112+ Vertex v2 = unpack(index .z );
113+
114+ // Interpolate normal
115+ const float3 barycentricCoords = float3(1 . 0 f - attribs .bary .x - attribs .bary .y , attribs .bary .x , attribs .bary .y );
116+ float3 normal = normalize(v0 .normal * barycentricCoords .x + v1 .normal * barycentricCoords .y + v2 .normal * barycentricCoords .z );
117+
118+ // Basic lighting
119+ float3 lightVector = normalize(ubo .lightPos .xyz );
120+ float dot_product = max(dot(lightVector, normal), 0 . 6 );
121+ rayPayload .color .rgb = v0 .color .rgb * dot_product;
122+ rayPayload .distance = RayTCurrent();
123+ rayPayload .normal = normal;
124+
125+ // Objects with full white vertex color are treated as reflectors
126+ rayPayload .reflector = ((v0 .color .r == 1 . 0 f ) && (v0 .color .g == 1 . 0 f ) && (v0 .color .b == 1 . 0 f )) ? 1 . 0 f : 0 . 0 f ;
127+ }
128+
129+ [shader(" miss" )]
130+ void missMain(inout RayPayload rayPayload)
131+ {
132+ float3 worldRayDirection = WorldRayDirection();
133+
134+ // View-independent background gradient to simulate a basic sky background
135+ const float3 gradientStart = float3(0 . 5 , 0 . 6 , 1 . 0 );
136+ const float3 gradientEnd = float3(1 . 0 , 1 . 0 , 1 . 0 );
137+ float3 unitDir = normalize(worldRayDirection);
138+ float t = 0 . 5 * (unitDir .y + 1 . 0 );
139+ rayPayload .color = (1 . 0 - t) * gradientStart + t * gradientEnd;
140+
141+ rayPayload .distance = - 1 . 0 f ;
142+ rayPayload .normal = float3(0 , 0 , 0 );
143+ rayPayload .reflector = 0 . 0 f ;
144+ }
0 commit comments