Skip to content

Commit aa22a94

Browse files
committed
asm fn_ptr: allow passing fn_ptr into asm! blocks and calling function
1 parent f8ac1a7 commit aa22a94

7 files changed

Lines changed: 219 additions & 6 deletions

File tree

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
33

44
use super::Builder;
55
use crate::abi::ConvSpirvType;
6-
use crate::builder_spirv::SpirvValue;
6+
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
77
use crate::codegen_cx::CodegenCx;
88
use crate::spirv_type::SpirvType;
99
use rspirv::dr;
@@ -127,12 +127,20 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
127127
};
128128

129129
if let Some(in_value) = in_value
130-
&& let (BackendRepr::Scalar(scalar), OperandValue::Immediate(in_value_spv)) =
131-
(in_value.layout.backend_repr, &mut in_value.val)
132-
&& let Primitive::Pointer(_) = scalar.primitive()
130+
&& let OperandValue::Immediate(in_value_spv) = &mut in_value.val
133131
{
134-
let in_value_precise_type = in_value.layout.spirv_type(self.span(), self);
135-
*in_value_spv = self.pointercast(*in_value_spv, in_value_precise_type);
132+
if let SpirvValueKind::FnAddr { function } = in_value_spv.kind
133+
&& let SpirvType::Pointer { pointee } = self.lookup_type(in_value_spv.ty)
134+
{
135+
// reference to function pointer must be unwrapped from its pointer to be used in calls
136+
*in_value_spv = function.with_type(pointee);
137+
} else if let BackendRepr::Scalar(scalar) = in_value.layout.backend_repr
138+
&& let Primitive::Pointer(_) = scalar.primitive()
139+
{
140+
// ordinary SPIR-V value
141+
let in_value_precise_type = in_value.layout.spirv_type(self.span(), self);
142+
*in_value_spv = self.pointercast(*in_value_spv, in_value_precise_type);
143+
}
136144
}
137145
if let Some(out_place) = out_place {
138146
let out_place_precise_type = out_place.layout.spirv_type(self.span(), self);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "; .*\n" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-vulkan1.0
14+
// ignore-vulkan1.1
15+
16+
use core::arch::asm;
17+
use spirv_std::glam::*;
18+
use spirv_std::spirv;
19+
20+
pub fn add_one(a: f32) -> f32 {
21+
a + 1.
22+
}
23+
24+
#[spirv(fragment)]
25+
pub fn main(a: f32, result: &mut f32) {
26+
unsafe {
27+
asm! {
28+
"%f32 = OpTypeFloat 32",
29+
"%result = OpFunctionCall %f32 {func} {a}",
30+
"OpStore {result} %result",
31+
func = in(reg) add_one,
32+
a = in(reg) a,
33+
result = in(reg) result,
34+
}
35+
}
36+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint Fragment %1 "main" %2 %3
4+
OpExecutionMode %1 OriginUpperLeft
5+
OpName %2 "a"
6+
OpName %3 "result"
7+
OpName %5 "fn_ptr_call_float::add_one"
8+
OpDecorate %2 Location 0
9+
OpDecorate %3 Location 0
10+
%6 = OpTypeFloat 32
11+
%7 = OpTypePointer Input %6
12+
%8 = OpTypePointer Output %6
13+
%9 = OpTypeVoid
14+
%10 = OpTypeFunction %9
15+
%2 = OpVariable %7 Input
16+
%11 = OpTypeFunction %6 %6
17+
%12 = OpConstant %6 1
18+
%3 = OpVariable %8 Output
19+
%1 = OpFunction %9 None %10
20+
%13 = OpLabel
21+
%14 = OpLoad %6 %2
22+
%15 = OpFunctionCall %6 %5 %14
23+
OpStore %3 %15
24+
OpNoLine
25+
OpReturn
26+
OpFunctionEnd
27+
%5 = OpFunction %6 None %11
28+
%16 = OpFunctionParameter %6
29+
%17 = OpLabel
30+
%18 = OpFAdd %6 %16 %12
31+
OpNoLine
32+
OpReturnValue %18
33+
OpFunctionEnd
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "; .*\n" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-vulkan1.0
14+
// ignore-vulkan1.1
15+
16+
use core::arch::asm;
17+
use spirv_std::glam::*;
18+
use spirv_std::spirv;
19+
20+
pub fn add_one(a: Vec4) -> Vec4 {
21+
a + 1.
22+
}
23+
24+
#[spirv(fragment)]
25+
pub fn main(a: Vec4, result: &mut Vec4) {
26+
unsafe {
27+
asm! {
28+
"%f32 = OpTypeFloat 32",
29+
"%a = OpLoad _ {a}",
30+
"%result = OpFunctionCall typeof*{result} {func} %a",
31+
"OpStore {result} %result",
32+
func = in(reg) add_one,
33+
a = in(reg) &a,
34+
result = in(reg) result,
35+
}
36+
}
37+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint Fragment %1 "main" %2 %3
4+
OpExecutionMode %1 OriginUpperLeft
5+
OpName %2 "a"
6+
OpName %3 "result"
7+
OpName %7 "fn_ptr_call_vec::add_one"
8+
OpDecorate %2 Location 0
9+
OpDecorate %3 Location 0
10+
%8 = OpTypeFloat 32
11+
%9 = OpTypeVector %8 4
12+
%10 = OpTypePointer Input %9
13+
%11 = OpTypePointer Output %9
14+
%12 = OpTypeVoid
15+
%13 = OpTypeFunction %12
16+
%2 = OpVariable %10 Input
17+
%14 = OpTypeFunction %9 %9
18+
%15 = OpConstant %8 1
19+
%3 = OpVariable %11 Output
20+
%1 = OpFunction %12 None %13
21+
%16 = OpLabel
22+
%17 = OpLoad %9 %2
23+
%18 = OpFunctionCall %9 %7 %17
24+
OpStore %3 %18
25+
OpNoLine
26+
OpReturn
27+
OpFunctionEnd
28+
%7 = OpFunction %9 None %14
29+
%19 = OpFunctionParameter %9
30+
%20 = OpLabel
31+
%21 = OpCompositeConstruct %9 %15 %15 %15 %15
32+
%22 = OpFAdd %9 %19 %21
33+
OpNoLine
34+
OpReturnValue %22
35+
OpFunctionEnd
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "; .*\n" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-vulkan1.0
14+
// ignore-vulkan1.1
15+
16+
use core::arch::asm;
17+
use spirv_std::glam::*;
18+
use spirv_std::spirv;
19+
20+
pub fn my_func() {
21+
spirv_std::arch::kill()
22+
}
23+
24+
#[spirv(fragment)]
25+
pub fn main() {
26+
unsafe {
27+
asm! {
28+
"%void = OpTypeVoid",
29+
"%ignore = OpFunctionCall %void {func}",
30+
func = in(reg) my_func,
31+
}
32+
}
33+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint Fragment %1 "main"
4+
OpExecutionMode %1 OriginUpperLeft
5+
OpName %3 "fn_ptr_call_void::main"
6+
OpName %4 "fn_ptr_call_void::my_func"
7+
OpName %5 "spirv_std::arch::kill"
8+
%6 = OpTypeVoid
9+
%7 = OpTypeFunction %6
10+
%1 = OpFunction %6 None %7
11+
%8 = OpLabel
12+
%9 = OpFunctionCall %6 %3
13+
OpNoLine
14+
OpReturn
15+
OpFunctionEnd
16+
%3 = OpFunction %6 None %7
17+
%10 = OpLabel
18+
%11 = OpFunctionCall %6 %4
19+
OpNoLine
20+
OpReturn
21+
OpFunctionEnd
22+
%4 = OpFunction %6 None %7
23+
%12 = OpLabel
24+
%13 = OpFunctionCall %6 %5
25+
OpNoLine
26+
OpUnreachable
27+
OpFunctionEnd
28+
%5 = OpFunction %6 None %7
29+
%14 = OpLabel
30+
OpKill
31+
OpFunctionEnd

0 commit comments

Comments
 (0)