Skip to content

Commit 5d4eff8

Browse files
Expose hooks to load ET pybindings with your own data loader
Differential Revision: D92431122 Pull Request resolved: #17255
1 parent 19e8b68 commit 5d4eff8

9 files changed

Lines changed: 254 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,18 @@ if(EXECUTORCH_BUILD_PYBIND)
10441044
EXPORT ExecuTorchTargets
10451045
LIBRARY DESTINATION executorch/extension/pybindings
10461046
)
1047+
1048+
# pybind data_loader module - provides PyDataLoader type for external
1049+
# pybinding extensions to create custom data loaders
1050+
pybind11_add_module(
1051+
data_loader SHARED extension/pybindings/pybindings_data_loader.cpp
1052+
)
1053+
target_include_directories(data_loader PRIVATE ${_common_include_directories})
1054+
target_compile_options(data_loader PUBLIC ${_pybind_compile_options})
1055+
target_link_libraries(data_loader PRIVATE executorch)
1056+
install(TARGETS data_loader
1057+
LIBRARY DESTINATION executorch/extension/pybindings
1058+
)
10471059
endif()
10481060

10491061
if(EXECUTORCH_BUILD_WASM)

extension/pybindings/BUCK

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
3+
load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension")
24
# Any targets that should be shared between fbcode and xplat must be defined in
35
# targets.bzl. This file can contain fbcode-only targets.
46

@@ -69,3 +71,33 @@ fbcode_target(_kind = runtime.python_library,
6971
"//executorch/exir:_warnings",
7072
],
7173
)
74+
75+
# Header-only library that provides PyDataLoader for external pybinding extensions.
76+
# This allows external libraries (like PTEZ) to create custom data loaders that can
77+
# be passed to _load_for_executorch_from_data_loader().
78+
fbcode_target(
79+
_kind = cpp_library,
80+
name = "data_loader_types",
81+
headers = ["pybindings_data_loader.h"],
82+
exported_deps = [
83+
"//executorch/runtime/core:core",
84+
],
85+
visibility = ["PUBLIC"],
86+
)
87+
88+
# Python extension that registers the PyDataLoader type.
89+
# This allows external libraries to create PyDataLoader instances without
90+
# importing the full core pybindings.
91+
fbcode_target(
92+
_kind = cpp_python_extension,
93+
name = "data_loader",
94+
srcs = ["pybindings_data_loader.cpp"],
95+
base_module = "executorch.extension.pybindings",
96+
deps = [
97+
":data_loader_types",
98+
],
99+
external_deps = [
100+
"pybind11",
101+
],
102+
visibility = ["PUBLIC"],
103+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
class PyDataLoader:
    """Opaque handle for a DataLoader, registered from C++ via pybind11."""

    ...

extension/pybindings/pybindings.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
2525
#include <executorch/extension/module/bundled_module.h>
2626
#include <executorch/extension/module/module.h>
27+
#include <executorch/extension/pybindings/pybindings_data_loader.h>
2728
#include <executorch/extension/tensor/tensor_ptr.h>
2829
#include <executorch/extension/tensor/tensor_ptr_maker.h>
2930
#include <executorch/extension/threadpool/threadpool.h>
@@ -85,6 +86,8 @@ using ::executorch::extension::BufferDataLoader;
8586
using ::executorch::extension::MallocMemoryAllocator;
8687
using ::executorch::extension::MmapDataLoader;
8788
using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
89+
using ::executorch::extension::pybindings::PyDataLoader;
90+
using ::executorch::extension::pybindings::SharedPtrDataLoader;
8891
using ::executorch::runtime::ArrayRef;
8992
using ::executorch::runtime::DataLoader;
9093
using ::executorch::runtime::Error;
@@ -246,6 +249,29 @@ inline std::unique_ptr<Module> load_module_from_buffer_with_data_file(
246249
std::move(data_loader));
247250
}
248251

252+
inline std::unique_ptr<Module> load_module_from_data_loader(
253+
std::shared_ptr<PyDataLoader> loader,
254+
std::optional<const std::string> data_map_path,
255+
std::unique_ptr<runtime::EventTracer> event_tracer) {
256+
EXECUTORCH_SCOPE_PROF("load_module_from_data_loader");
257+
258+
if (data_map_path.has_value()) {
259+
auto data_map_loader = loader_from_file(data_map_path.value());
260+
return std::make_unique<Module>(
261+
loader->make_delegating_loader(),
262+
nullptr, // memory_allocator
263+
nullptr, // temp_allocator
264+
std::move(event_tracer), // event_tracer
265+
std::move(data_map_loader)); // data_map_loader
266+
}
267+
return std::make_unique<Module>(
268+
loader->make_delegating_loader(),
269+
nullptr, // memory_allocator
270+
nullptr, // temp_allocator
271+
std::move(event_tracer), // event_tracer
272+
nullptr); // data_map_loader
273+
}
274+
249275
inline py::list get_outputs_as_py_list(
250276
const std::vector<EValue>& outputs,
251277
bool clone_outputs = true) {
@@ -601,6 +627,17 @@ struct PyModule final {
601627
setup_event_tracer(enable_etdump, debug_buffer_size),
602628
program_verification)) {}
603629

630+
explicit PyModule(
631+
std::shared_ptr<PyDataLoader> loader,
632+
std::optional<const std::string> data_path,
633+
bool enable_etdump,
634+
size_t debug_buffer_size = 0)
635+
: debug_buffer_size_(debug_buffer_size),
636+
module_(load_module_from_data_loader(
637+
std::move(loader),
638+
data_path,
639+
setup_event_tracer(enable_etdump, debug_buffer_size))) {}
640+
604641
PyModule(const PyModule&) = delete;
605642
PyModule& operator=(const PyModule&) = delete;
606643
PyModule(PyModule&&) = default;
@@ -676,6 +713,17 @@ struct PyModule final {
676713
Program::Verification::InternalConsistency);
677714
}
678715

716+
// Load from an external data loader.
717+
// This allows external libraries (like PTEZ) to provide custom data loaders.
718+
static std::unique_ptr<PyModule> load_from_data_loader(
719+
std::shared_ptr<PyDataLoader> loader,
720+
std::optional<const std::string> data_path,
721+
bool enable_etdump,
722+
size_t debug_buffer_size = 0) {
723+
return std::make_unique<PyModule>(
724+
std::move(loader), data_path, enable_etdump, debug_buffer_size);
725+
}
726+
679727
py::list run_method(
680728
const std::string& method_name,
681729
const py::sequence& inputs,
@@ -1529,6 +1577,20 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
15291577
py::arg("buffer"),
15301578
py::arg("non_const_pool_size") = kDEFAULT_BUNDLED_INPUT_POOL_SIZE,
15311579
call_guard);
1580+
1581+
// Import the PyDataLoader type from the shared module.
1582+
// This ensures the type is registered once and shared across all modules.
1583+
py::module_::import("executorch.extension.pybindings.data_loader");
1584+
1585+
m.def(
1586+
"_load_for_executorch_from_data_loader",
1587+
&PyModule::load_from_data_loader,
1588+
py::arg("loader"),
1589+
py::arg("data_path") = py::none(),
1590+
py::arg("enable_etdump") = false,
1591+
py::arg("debug_buffer_size") = 0,
1592+
call_guard);
1593+
15321594
m.def(
15331595
"_dump_profile_results",
15341596
[]() {

extension/pybindings/pybindings.pyi

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ class MethodMeta:
204204

205205
def __repr__(self) -> str: ...
206206

207+
# Re-export PyDataLoader from the shared data_loader module so the signatures below can reference it.
208+
from executorch.extension.pybindings.data_loader import PyDataLoader as PyDataLoader
209+
207210
@experimental("This API is experimental and subject to change without notice.")
208211
def _load_for_executorch(
209212
program_path: str,
@@ -265,6 +268,33 @@ def _load_for_executorch_from_bundled_program(
265268
"""
266269
...
267270

271+
@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_from_data_loader(
    loader: PyDataLoader,
    data_path: Optional[str] = None,
    enable_etdump: bool = False,
    debug_buffer_size: int = 0,
) -> ExecuTorchModule:
    """Load an ExecuTorch Program through a caller-supplied PyDataLoader.

    External libraries can use this entry point to plug in their own data
    loaders (e.g., for compressed files) and load programs through them.

    .. warning::

        This API is experimental and subject to change without notice.

    Args:
        loader: PyDataLoader wrapping the custom DataLoader implementation.
        data_path: Optional path to a data file (e.g., for external weights).
        enable_etdump: If true, enables an ETDump which can store profiling
            information.
        debug_buffer_size: If non-zero, enables a debug buffer for
            intermediate results.

    Returns:
        An ExecuTorchModule ready for execution.
    """
    ...
297+
268298
@experimental("This API is experimental and subject to change without notice.")
269299
def _load_bundled_program_from_buffer(
270300
buffer: bytes, non_const_pool_size: int = ...
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pybind11/pybind11.h>
10+
11+
#include <executorch/extension/pybindings/pybindings_data_loader.h>
12+
13+
namespace py = pybind11;
14+
15+
using ::executorch::extension::pybindings::PyDataLoader;
16+
17+
// Defines the `data_loader` extension module. It only registers the
// PyDataLoader type, held by std::shared_ptr; no constructor or methods are
// bound here.
PYBIND11_MODULE(data_loader, m) {
  using PyDataLoaderHolder = std::shared_ptr<PyDataLoader>;
  py::class_<PyDataLoader, PyDataLoaderHolder>(m, "PyDataLoader");
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
13+
#include <executorch/runtime/core/data_loader.h>
14+
#include <executorch/runtime/core/freeable_buffer.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/platform/compiler.h>
17+
18+
namespace executorch {
19+
namespace extension {
20+
namespace pybindings {
21+
22+
/// DataLoader wrapper holding a shared_ptr, allowing sharing between Python
23+
/// and C++ while Module takes ownership via unique_ptr.
24+
class SharedPtrDataLoader final : public runtime::DataLoader {
25+
public:
26+
explicit SharedPtrDataLoader(std::shared_ptr<runtime::DataLoader> loader)
27+
: loader_(std::move(loader)) {}
28+
29+
ET_NODISCARD runtime::Result<runtime::FreeableBuffer> load(
30+
size_t offset,
31+
size_t size,
32+
const SegmentInfo& segment_info) const override {
33+
return loader_->load(offset, size, segment_info);
34+
}
35+
36+
ET_NODISCARD runtime::Result<size_t> size() const override {
37+
return loader_->size();
38+
}
39+
40+
ET_NODISCARD runtime::Error load_into(
41+
size_t offset,
42+
size_t size,
43+
const SegmentInfo& segment_info,
44+
void* buffer) const override {
45+
return loader_->load_into(offset, size, segment_info, buffer);
46+
}
47+
48+
private:
49+
std::shared_ptr<runtime::DataLoader> loader_;
50+
};
51+
52+
/// Pybind11 wrapper for DataLoader. Use shared_ptr holder type in pybind11.
53+
struct PyDataLoader {
54+
explicit PyDataLoader(std::shared_ptr<runtime::DataLoader> loader)
55+
: loader_(std::move(loader)) {}
56+
57+
PyDataLoader(const PyDataLoader&) = delete;
58+
PyDataLoader& operator=(const PyDataLoader&) = delete;
59+
PyDataLoader(PyDataLoader&&) = default;
60+
PyDataLoader& operator=(PyDataLoader&&) = default;
61+
62+
std::shared_ptr<runtime::DataLoader> get() const {
63+
return loader_;
64+
}
65+
66+
/// Creates a unique_ptr DataLoader that delegates to the shared loader.
67+
std::unique_ptr<runtime::DataLoader> make_delegating_loader() const {
68+
return std::make_unique<SharedPtrDataLoader>(loader_);
69+
}
70+
71+
private:
72+
std::shared_ptr<runtime::DataLoader> loader_;
73+
};
74+
75+
} // namespace pybindings
76+
} // namespace extension
77+
} // namespace executorch

setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ def run(self): # noqa C901
767767

768768
if cmake_cache.is_enabled("EXECUTORCH_BUILD_PYBIND"):
769769
cmake_build_args += ["--target", "portable_lib"]
770+
cmake_build_args += ["--target", "data_loader"]
770771
cmake_build_args += ["--target", "selective_build"]
771772

772773
if cmake_cache.is_enabled("EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER"):
@@ -838,6 +839,13 @@ def run(self): # noqa C901
838839
modpath="executorch.extension.pybindings._portable_lib",
839840
dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"],
840841
),
842+
# Install the data_loader pybindings extension which provides the
843+
# PyDataLoader type for external pybinding extensions.
844+
BuiltExtension(
845+
src="data_loader.cp*" if _is_windows() else "data_loader.*",
846+
modpath="executorch.extension.pybindings.data_loader",
847+
dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"],
848+
),
841849
BuiltExtension(
842850
src="extension/training/_training_lib.*", # @lint-ignore https://github.com/pytorch/executorch/blob/cb3eba0d7f630bc8cec0a9cc1df8ae2f17af3f7a/scripts/lint_xrefs.sh
843851
modpath="executorch.extension.training.pybindings._training_lib",

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibilit
6161
deps = [
6262
"//executorch/runtime/core:core",
6363
"//executorch/extension/threadpool:threadpool",
64+
"//executorch/extension/pybindings:data_loader_types",
65+
"//executorch/extension/pybindings:data_loader",
6466
] + cppdeps,
6567
external_deps = [
6668
"pybind11",

0 commit comments

Comments
 (0)