-
Notifications
You must be signed in to change notification settings - Fork 105
Expand file tree
/
Copy pathnaga_transpile.rs
More file actions
89 lines (81 loc) · 2.96 KB
/
naga_transpile.rs
File metadata and controls
89 lines (81 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
use crate::codegen_cx::CodegenArgs;
use rustc_codegen_spirv_target_specs::SpirvTargetEnv;
use rustc_session::Session;
use rustc_span::ErrorGuaranteed;
use std::path::Path;
/// Function-pointer signature of a transpile step selected by [`should_transpile`].
///
/// The callee receives the compiler session (used for reporting diagnostics),
/// the codegen arguments, the SPIR-V module as a slice of `u32` words, and the
/// output filename. It returns `Ok(())` on success, or an [`ErrorGuaranteed`]
/// once an error has been reported.
pub type NagaTranspile = fn(
sess: &Session,
cg_args: &CodegenArgs,
spv_binary: &[u32],
out_filename: &Path,
) -> Result<(), ErrorGuaranteed>;
/// Decides whether the session's target needs a Naga transpile pass.
///
/// Returns `Ok(Some(f))` with the transpile function when the target is
/// `SpirvTargetEnv::Naga_Wgsl` and this crate was built with the `naga`
/// feature, and `Ok(None)` for every other target.
///
/// # Errors
///
/// Reports a diagnostic through the session and returns its
/// [`ErrorGuaranteed`] when the target is `SpirvTargetEnv::Naga_Wgsl` but the
/// `naga` feature is disabled.
///
/// # Panics
///
/// Panics if the session's target triple cannot be parsed into a
/// `SpirvTargetEnv`; the `expect` message states that this should already
/// have been rejected earlier.
pub fn should_transpile(sess: &Session) -> Result<Option<NagaTranspile>, ErrorGuaranteed> {
    let target = SpirvTargetEnv::parse_triple(sess.opts.target_triple.tuple())
        .expect("parsing should fail earlier");

    // Exactly one of the two `Naga_Wgsl` arms survives conditional compilation.
    match target {
        #[cfg(feature = "naga")]
        SpirvTargetEnv::Naga_Wgsl => Ok(Some(transpile::wgsl_transpile)),
        #[cfg(not(feature = "naga"))]
        SpirvTargetEnv::Naga_Wgsl => Err(sess.dcx().err(format!(
            "Target {} requires feature \"naga\" on rustc_codegen_spirv",
            target.target_triple()
        ))),
        _ => Ok(None),
    }
}
/// Naga-backed transpilation, compiled only when the `naga` feature is enabled.
#[cfg(feature = "naga")]
mod transpile {
    use crate::codegen_cx::CodegenArgs;
    use naga::error::ShaderError;
    use naga::valid::Capabilities;
    use rustc_session::Session;
    use rustc_span::ErrorGuaranteed;
    use std::path::Path;

    /// Wraps a Naga error in a [`ShaderError`] with empty source text and no
    /// label, so it can be rendered through `ShaderError`'s `Display` impl.
    fn without_source<E>(inner: E) -> ShaderError<E> {
        ShaderError {
            source: String::new(),
            label: None,
            inner: Box::new(inner),
        }
    }

    /// Transpiles a SPIR-V module to WGSL and writes it next to `out_filename`.
    ///
    /// The SPIR-V words are parsed with Naga's SPIR-V frontend, the resulting
    /// module is validated, WGSL text is generated from it, and that text is
    /// written to `out_filename` with its extension replaced by `wgsl`.
    ///
    /// # Errors
    ///
    /// Each failing step (parse, validation, WGSL generation, file write)
    /// reports a diagnostic through `sess` and returns the resulting
    /// [`ErrorGuaranteed`].
    pub fn wgsl_transpile(
        sess: &Session,
        _cg_args: &CodegenArgs,
        spv_binary: &[u32],
        out_filename: &Path,
    ) -> Result<(), ErrorGuaranteed> {
        // TODO: these settings should become parameters supplied via spirv-builder.
        let frontend_opts = naga::front::spv::Options::default();
        let caps = Capabilities::default();
        let wgsl_flags = naga::back::wgsl::WriterFlags::empty();

        // Reinterpret the `u32` words as bytes for the byte-slice parser entry point.
        let spv_bytes: &[u8] = bytemuck::cast_slice(spv_binary);
        let module = match naga::front::spv::parse_u8_slice(spv_bytes, &frontend_opts) {
            Ok(parsed) => parsed,
            Err(err) => {
                return Err(sess.dcx().err(format!(
                    "Naga failed to parse spv: \n{}",
                    without_source(err)
                )));
            }
        };

        let mut checker =
            naga::valid::Validator::new(naga::valid::ValidationFlags::default(), caps);
        let info = match checker.validate(&module) {
            Ok(module_info) => module_info,
            Err(err) => {
                return Err(sess.dcx().err(format!(
                    "Naga validation failed: \n{}",
                    without_source(err)
                )));
            }
        };

        let wgsl_dst = out_filename.with_extension("wgsl");
        let wgsl = match naga::back::wgsl::write_string(&module, &info, wgsl_flags) {
            Ok(text) => text,
            Err(err) => {
                return Err(sess
                    .dcx()
                    .err(format!("Naga failed to write wgsl : \n{err}")));
            }
        };

        // The write result is the function result: `Ok(())` or the reported error.
        std::fs::write(&wgsl_dst, wgsl).map_err(|err| {
            sess.dcx()
                .err(format!("failed to write wgsl to file: {err}"))
        })
    }
}