Skip to content

Commit e8487f3

Browse files
[ET Device Support] Parse device info from serialized tensor in tensor_parser (#18966)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #18328 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/143/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/143/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/143/orig Differential Revision: [D97199497](https://our.internmc.facebook.com/intern/diff/D97199497/) @diff-train-skip-merge Co-authored-by: gasoonjia <gasoonjia@icloud.com>
1 parent f4019c3 commit e8487f3

5 files changed

Lines changed: 346 additions & 1 deletion

File tree

runtime/executor/tensor_parser_portable.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ Result<Tensor> parseTensor(
147147
Internal,
148148
"dim_order_to_stride returned invalid status");
149149

150+
// Extract device info from serialized tensor metadata.
151+
// Defaults to CPU/0 for backward compatibility when extra_tensor_info is
152+
// absent (e.g., older PTE files without device annotations).
153+
auto device_type = executorch::runtime::etensor::DeviceType::CPU;
154+
executorch::runtime::etensor::DeviceIndex device_index = 0;
155+
if (s_tensor->extra_tensor_info() != nullptr) {
156+
device_type = static_cast<executorch::runtime::etensor::DeviceType>(
157+
s_tensor->extra_tensor_info()->device_type());
158+
device_index = static_cast<executorch::runtime::etensor::DeviceIndex>(
159+
s_tensor->extra_tensor_info()->device_index());
160+
}
161+
150162
auto* tensor_impl = method_allocator->allocateInstance<TensorImpl>();
151163
if (tensor_impl == nullptr) {
152164
return Error::MemoryAllocationFailed;
@@ -161,7 +173,9 @@ Result<Tensor> parseTensor(
161173
/*data=*/nullptr,
162174
dim_order,
163175
strides,
164-
dynamism);
176+
dynamism,
177+
device_type,
178+
device_index);
165179

166180
// Now that we know how big the tensor is, find and assign its memory.
167181
Result<void*> data_ptr = getTensorDataPtr(

runtime/executor/test/targets.bzl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,19 @@ def define_common_targets(is_fbcode = False):
312312
],
313313
env = modules_env,
314314
)
315+
316+
runtime.cxx_test(
317+
name = "tensor_parser_device_test",
318+
srcs = [
319+
"tensor_parser_device_test.cpp",
320+
],
321+
deps = [
322+
":managed_memory_manager",
323+
"//executorch/runtime/executor:program",
324+
"//executorch/extension/data_loader:file_data_loader",
325+
"//executorch/schema:program",
326+
],
327+
env = {
328+
"ET_MODULE_ADD_WITH_DEVICE_PATH": "$(location fbcode//executorch/test/models:exported_program_with_device_info[ModuleAddWithDevice.pte])",
329+
},
330+
)
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* Tests that device info (device_type) is correctly parsed from serialized
11+
* tensors in .pte files into TensorImpl at runtime.
12+
*
13+
* Uses a .pte exported with DeviceAwarePartitioner (CUDA device annotation)
14+
* so that delegate output tensors carry device_type=CUDA in ExtraTensorInfo.
15+
*/
16+
17+
#include <executorch/runtime/executor/tensor_parser.h>

#include <cstdlib>
#include <memory>
#include <utility>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/schema/program_generated.h>

#include <gtest/gtest.h>
25+
26+
using executorch::aten::Tensor;
27+
using executorch::runtime::Error;
28+
using executorch::runtime::Program;
29+
using executorch::runtime::Result;
30+
using executorch::runtime::deserialization::parseTensor;
31+
using executorch::runtime::testing::ManagedMemoryManager;
32+
using torch::executor::util::FileDataLoader;
33+
34+
// Memory budgets handed to ManagedMemoryManager for each test below.
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;

namespace executorch {
namespace runtime {
namespace testing {
// Test-only accessor for Program internals. Exposes the deserialized
// flatbuffer root so the tests can iterate over serialized tensors directly.
// NOTE(review): relies on a friend declaration inside Program (not visible in
// this file) to reach the private internal_program_ member — confirm.
class ProgramTestFriend final {
 public:
  // Returns the program's underlying flatbuffer representation.
  const static executorch_flatbuffer::Program* GetInternalProgram(
      const Program* program) {
    return program->internal_program_;
  }
};
} // namespace testing
} // namespace runtime
} // namespace executorch

using executorch::runtime::testing::ProgramTestFriend;
52+
53+
// Fixture that loads the device-annotated .pte exported by
// export_program_with_device_info.py. The path is injected through the
// ET_MODULE_ADD_WITH_DEVICE_PATH env var by the build rule.
class TensorParserDeviceTest : public ::testing::Test {
 protected:
  void SetUp() override {
    const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH");
    // Fail fast with a clear message if the test is run outside the harness
    // that sets the env var.
    ASSERT_NE(path, nullptr)
        << "ET_MODULE_ADD_WITH_DEVICE_PATH env var not set";
    Result<FileDataLoader> loader = FileDataLoader::from(path);
    ASSERT_EQ(loader.error(), Error::Ok);
    // Move the loader onto the heap so it outlives SetUp and can back the
    // Program loaded inside each test body.
    loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
  }

  // Data loader for the .pte file; populated in SetUp, used by each test.
  std::unique_ptr<FileDataLoader> loader_;
};
66+
67+
// Walks every serialized tensor of the program's first execution plan and
// checks that tensors annotated with CUDA in ExtraTensorInfo come out of
// parseTensor with device_type == CUDA, while everything else defaults to CPU.
TEST_F(TensorParserDeviceTest, CUDADeviceParsedFromPteFile) {
  Result<Program> program =
      Program::load(loader_.get(), Program::Verification::Minimal);
  ASSERT_EQ(program.error(), Error::Ok);

  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

  // Reach into the program's flatbuffer (via the test friend) to enumerate
  // the serialized values of execution plan 0.
  const executorch_flatbuffer::Program* internal_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  auto* execution_plan =
      internal_program->execution_plan()->GetMutableObject(0);
  auto* flatbuffer_values = execution_plan->values();

  int cuda_tensor_count = 0;
  int cpu_tensor_count = 0;

  for (uint32_t i = 0; i < flatbuffer_values->size(); ++i) {
    auto* serialization_value = flatbuffer_values->Get(i);
    // Only tensor-typed values carry device annotations; skip the rest.
    if (serialization_value->val_type() !=
        executorch_flatbuffer::KernelTypes::Tensor) {
      continue;
    }

    auto* s_tensor = serialization_value->val_as_Tensor();

    Result<Tensor> tensor = parseTensor(&program.get(), &mmm.get(), s_tensor);
    if (!tensor.ok()) {
      // A CUDA-annotated tensor that fails to parse still counts toward the
      // CUDA total so the expected totals below hold.
      // NOTE(review): presumably delegate tensors can legitimately fail data
      // resolution in this setup — confirm the failure is expected.
      bool has_cuda = s_tensor->extra_tensor_info() != nullptr &&
          s_tensor->extra_tensor_info()->device_type() ==
              executorch_flatbuffer::DeviceType::CUDA;
      if (has_cuda) {
        cuda_tensor_count++;
      }
      continue;
    }

    Tensor t = tensor.get();
    auto device_type = t.unsafeGetTensorImpl()->device_type();

    if (device_type == executorch::runtime::etensor::DeviceType::CUDA) {
      cuda_tensor_count++;
      // device_index defaults to 0; the exporter annotates cuda:0.
      EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
          << "CUDA tensor should have device_index=0";
    } else {
      // Anything not CUDA must have fallen back to the CPU default.
      EXPECT_EQ(device_type, executorch::runtime::etensor::DeviceType::CPU);
      EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
          << "CPU tensor should have device_index=0";
      cpu_tensor_count++;
    }
  }

  // Expected counts are specific to ModuleAddWithDevice (a single delegated
  // add): 2 delegate inputs + 1 delegate output, all annotated CUDA.
  EXPECT_EQ(cuda_tensor_count, 3)
      << "Expected 3 CUDA tensors (2 delegate inputs + 1 delegate output)";
  EXPECT_EQ(cpu_tensor_count, 0)
      << "Expected 0 CPU tensors (all annotated as CUDA)";
}
123+
124+
// Every serialized tensor with no CUDA annotation must come out of the parser
// with the default device (CPU) and device_index 0.
TEST_F(TensorParserDeviceTest, NonDelegatedTensorsDefaultToCPU) {
  Result<Program> program =
      Program::load(loader_.get(), Program::Verification::Minimal);
  ASSERT_EQ(program.error(), Error::Ok);

  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

  const executorch_flatbuffer::Program* fb_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  auto* plan = fb_program->execution_plan()->GetMutableObject(0);
  auto* values = plan->values();

  // Predicate: does the serialized tensor explicitly name CUDA as its device?
  const auto annotated_cuda = [](const auto* st) {
    const auto* info = st->extra_tensor_info();
    return info != nullptr &&
        info->device_type() == executorch_flatbuffer::DeviceType::CUDA;
  };

  for (uint32_t i = 0; i < values->size(); ++i) {
    auto* value = values->Get(i);
    if (value->val_type() != executorch_flatbuffer::KernelTypes::Tensor) {
      continue;  // Non-tensor values carry no device info.
    }

    auto* s_tensor = value->val_as_Tensor();
    // Only check tensors that are NOT annotated as CUDA.
    if (annotated_cuda(s_tensor)) {
      continue;
    }

    Result<Tensor> parsed = parseTensor(&program.get(), &mmm.get(), s_tensor);
    if (!parsed.ok()) {
      continue;  // Unparseable tensors are exercised by the other test.
    }

    Tensor t = parsed.get();
    auto* impl = t.unsafeGetTensorImpl();
    EXPECT_EQ(
        impl->device_type(), executorch::runtime::etensor::DeviceType::CPU)
        << "Tensor at index " << i
        << " without CUDA annotation should default to CPU";
    EXPECT_EQ(impl->device_index(), 0)
        << "Tensor at index " << i
        << " without device annotation should have device_index=0";
  }
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
"""Exports a simple model with device-annotated tensors for C++ testing.
10+
11+
Uses DeviceAwarePartitioner (BackendWithCompilerDemo + target_device=cuda:0)
12+
so that delegate output tensors are annotated with CUDA device in the .pte.
13+
"""
14+
15+
import argparse
16+
import os
17+
from typing import Dict, final
18+
19+
import torch
20+
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
21+
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
22+
generate_pattern_op_partitions,
23+
)
24+
from executorch.exir.backend.compile_spec_schema import CompileSpec
25+
from executorch.exir.backend.partitioner import (
26+
DelegationSpec,
27+
Partitioner,
28+
PartitionResult,
29+
)
30+
from executorch.exir.backend.test.backend_with_compiler_demo import (
31+
BackendWithCompilerDemo,
32+
)
33+
from executorch.exir.dialects._ops import ops as exir_ops
34+
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY
35+
from torch import nn
36+
from torch.export import export
37+
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
38+
39+
40+
class _AddOperatorSupport(OperatorSupportBase):
41+
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
42+
return node.op == "call_function" and node.target in [
43+
exir_ops.edge.aten.add.Tensor,
44+
]
45+
46+
47+
@final
class _DeviceAwarePartitioner(Partitioner):
    """Partitioner that tags add ops for delegation with target_device=cuda:0."""

    def __init__(self) -> None:
        super().__init__()
        # "max_value" is a compile spec consumed by BackendWithCompilerDemo; the
        # TARGET_DEVICE spec is what causes delegate tensors to be annotated
        # with the CUDA device in the serialized program.
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ],
        )

    def partition(self, exported_program) -> PartitionResult:
        """Tag every add-op partition of the graph for delegation.

        Returns a PartitionResult mapping each partition's tag to the CUDA
        delegation spec; downstream lowering delegates the tagged nodes.
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            exported_program.graph_module,
            op_support=any_chain(_AddOperatorSupport()),
        )
        for partition in partition_list:
            # The tag is per-partition, not per-node: compute it once instead
            # of rebuilding the f-string and re-assigning the spec on every
            # node iteration (the original did both inside the inner loop).
            tag = f"tag{partition.id}"
            if partition.nodes:
                # Register the spec only when the partition actually has nodes,
                # matching the original behavior for (hypothetical) empty ones.
                partition_tags[tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=partition_tags,
        )
76+
77+
78+
class ModuleAddWithDevice(nn.Module):
    """Elementwise add of two tensors.

    The single add op is what _DeviceAwarePartitioner delegates, so the
    exported program's delegate tensors carry the CUDA device annotation.
    """

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # One aten.add.Tensor node — the only candidate for delegation.
        return torch.add(a, b)

    def get_random_inputs(self):
        # Two independent 2x2 float tensors for tracing/export.
        return tuple(torch.randn(2, 2) for _ in range(2))
86+
87+
88+
def main() -> None:
    """Export ModuleAddWithDevice to ``<outdir>/ModuleAddWithDevice.pte``.

    The add op is lowered through _DeviceAwarePartitioner, so the delegate
    tensors in the emitted .pte carry a CUDA device annotation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, required=True)
    args = parser.parse_args()

    # Fixed seed keeps the example inputs (and thus the emitted file)
    # reproducible across runs.
    torch.manual_seed(0)
    model = ModuleAddWithDevice()
    inputs = model.get_random_inputs()

    # NOTE(review): _check_ir_validity=False disables strict edge-dialect
    # validation — presumably required by the demo backend; confirm.
    edge = to_edge(
        export(model, inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    lowered = edge.to_backend(_DeviceAwarePartitioner())
    et_prog = lowered.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False))

    os.makedirs(args.outdir, exist_ok=True)
    outfile = os.path.join(args.outdir, "ModuleAddWithDevice.pte")

    # Serialized flatbuffer program bytes are exposed via .buffer.
    with open(outfile, "wb") as fp:
        fp.write(et_prog.buffer)
    print(f"Exported ModuleAddWithDevice to {outfile}")


if __name__ == "__main__":
    main()

test/models/targets.bzl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,27 @@ def define_common_targets():
141141
visibility = [], # Private
142142
)
143143

144+
runtime.python_library(
145+
name = "export_program_with_device_info_lib",
146+
srcs = ["export_program_with_device_info.py"],
147+
deps = [
148+
"//caffe2:torch",
149+
"//executorch/exir/backend/test:backend_with_compiler_demo",
150+
"//executorch/exir:lib",
151+
],
152+
visibility = [], # Private
153+
)
154+
155+
runtime.python_binary(
156+
name = "export_program_with_device_info",
157+
main_module = "executorch.test.models.export_program_with_device_info",
158+
par_style = "xar",
159+
deps = [
160+
":export_program_with_device_info_lib",
161+
],
162+
visibility = [], # Private
163+
)
164+
144165
runtime.python_binary(
145166
name = "export_delegated_program",
146167
main_module = "executorch.test.models.export_delegated_program",
@@ -196,6 +217,18 @@ def define_common_targets():
196217
],
197218
)
198219

220+
runtime.genrule(
221+
name = "exported_program_with_device_info",
222+
cmd = "$(exe :export_program_with_device_info) --outdir $OUT",
223+
outs = {
224+
"ModuleAddWithDevice.pte": ["ModuleAddWithDevice.pte"],
225+
},
226+
default_outs = ["."],
227+
visibility = [
228+
"//executorch/runtime/executor/test/...",
229+
],
230+
)
231+
199232
runtime.genrule(
200233
name = "exported_xnnp_delegated_programs",
201234
cmd = "$(exe :export_delegated_program)" +

0 commit comments

Comments
 (0)