24 | 24 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h> |
25 | 25 | #include <executorch/extension/module/bundled_module.h> |
26 | 26 | #include <executorch/extension/module/module.h> |
| 27 | +#include <executorch/extension/pybindings/pybindings_data_loader.h> |
27 | 28 | #include <executorch/extension/tensor/tensor_ptr.h> |
28 | 29 | #include <executorch/extension/tensor/tensor_ptr_maker.h> |
29 | 30 | #include <executorch/extension/threadpool/threadpool.h> |
@@ -85,6 +86,8 @@ using ::executorch::extension::BufferDataLoader; |
85 | 86 | using ::executorch::extension::MallocMemoryAllocator; |
86 | 87 | using ::executorch::extension::MmapDataLoader; |
87 | 88 | using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule; |
| 89 | +using ::executorch::extension::pybindings::PyDataLoader; |
| 90 | +using ::executorch::extension::pybindings::SharedPtrDataLoader; |
88 | 91 | using ::executorch::runtime::ArrayRef; |
89 | 92 | using ::executorch::runtime::DataLoader; |
90 | 93 | using ::executorch::runtime::Error; |
@@ -246,6 +249,29 @@ inline std::unique_ptr<Module> load_module_from_buffer_with_data_file( |
246 | 249 | std::move(data_loader)); |
247 | 250 | } |
248 | 251 |
| 252 | +inline std::unique_ptr<Module> load_module_from_data_loader( |
| 253 | + std::shared_ptr<PyDataLoader> loader, |
| 254 | + std::optional<const std::string> data_map_path, |
| 255 | + std::unique_ptr<runtime::EventTracer> event_tracer) { |
| 256 | + EXECUTORCH_SCOPE_PROF("load_module_from_data_loader"); |
| 257 | + |
| 258 | + if (data_map_path.has_value()) { |
| 259 | + auto data_map_loader = loader_from_file(data_map_path.value()); |
| 260 | + return std::make_unique<Module>( |
| 261 | + loader->make_delegating_loader(), |
| 262 | + nullptr, // memory_allocator |
| 263 | + nullptr, // temp_allocator |
| 264 | + std::move(event_tracer), // event_tracer |
| 265 | + std::move(data_map_loader)); // data_map_loader |
| 266 | + } |
| 267 | + return std::make_unique<Module>( |
| 268 | + loader->make_delegating_loader(), |
| 269 | + nullptr, // memory_allocator |
| 270 | + nullptr, // temp_allocator |
| 271 | + std::move(event_tracer), // event_tracer |
| 272 | + nullptr); // data_map_loader |
| 273 | +} |
| 274 | + |
249 | 275 | inline py::list get_outputs_as_py_list( |
250 | 276 | const std::vector<EValue>& outputs, |
251 | 277 | bool clone_outputs = true) { |
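
For orientation, the new `load_module_from_data_loader` helper mirrors the existing file- and buffer-based entry points: when `data_map_path` is supplied, the external data file is opened via `loader_from_file` and handed to `Module` as the data-map loader; without it, that slot stays null. On the Python side this surfaces only as the optional `data_path` argument of the binding registered later in this diff. A hedged sketch (the `loader` object is hypothetical and not part of this change):

```python
# Sketch, not part of this diff. `loader` is assumed to be a PyDataLoader
# produced by an external library; the binding itself is registered below.

# Weights embedded in the .pte stream: no external data file is needed.
module = _load_for_executorch_from_data_loader(loader)

# Program exported with external constants: point data_path at the data file.
module = _load_for_executorch_from_data_loader(loader, data_path="model.ptd")
```
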
@@ -601,6 +627,17 @@ struct PyModule final { |
601 | 627 | setup_event_tracer(enable_etdump, debug_buffer_size), |
602 | 628 | program_verification)) {} |
603 | 629 |
| 630 | + explicit PyModule( |
| 631 | + std::shared_ptr<PyDataLoader> loader, |
| 632 | + std::optional<const std::string> data_path, |
| 633 | + bool enable_etdump, |
| 634 | + size_t debug_buffer_size = 0) |
| 635 | + : debug_buffer_size_(debug_buffer_size), |
| 636 | + module_(load_module_from_data_loader( |
| 637 | + std::move(loader), |
| 638 | + data_path, |
| 639 | + setup_event_tracer(enable_etdump, debug_buffer_size))) {} |
| 640 | + |
604 | 641 | PyModule(const PyModule&) = delete; |
605 | 642 | PyModule& operator=(const PyModule&) = delete; |
606 | 643 | PyModule(PyModule&&) = default; |
@@ -676,6 +713,17 @@ struct PyModule final { |
676 | 713 | Program::Verification::InternalConsistency); |
677 | 714 | } |
678 | 715 |
| 716 | + // Load from an external data loader. |
| 717 | + // This allows external libraries (like PTEZ) to provide custom data loaders. |
| 718 | + static std::unique_ptr<PyModule> load_from_data_loader( |
| 719 | + std::shared_ptr<PyDataLoader> loader, |
| 720 | + std::optional<const std::string> data_path, |
| 721 | + bool enable_etdump, |
| 722 | + size_t debug_buffer_size = 0) { |
| 723 | + return std::make_unique<PyModule>( |
| 724 | + std::move(loader), data_path, enable_etdump, debug_buffer_size); |
| 725 | + } |
| 726 | + |
679 | 727 | py::list run_method( |
680 | 728 | const std::string& method_name, |
681 | 729 | const py::sequence& inputs, |
@@ -1529,6 +1577,20 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { |
1529 | 1577 | py::arg("buffer"), |
1530 | 1578 | py::arg("non_const_pool_size") = kDEFAULT_BUNDLED_INPUT_POOL_SIZE, |
1531 | 1579 | call_guard); |
| 1580 | + |
| 1581 | + // Import the PyDataLoader type from the shared module. |
| 1582 | + // This ensures the type is registered once and shared across all modules. |
| 1583 | + py::module_::import("executorch.extension.pybindings.data_loader"); |
| 1584 | + |
| 1585 | + m.def( |
| 1586 | + "_load_for_executorch_from_data_loader", |
| 1587 | + &PyModule::load_from_data_loader, |
| 1588 | + py::arg("loader"), |
| 1589 | + py::arg("data_path") = py::none(), |
| 1590 | + py::arg("enable_etdump") = false, |
| 1591 | + py::arg("debug_buffer_size") = 0, |
| 1592 | + call_guard); |
| 1593 | + |
1532 | 1594 | m.def( |
1533 | 1595 | "_dump_profile_results", |
1534 | 1596 | []() { |
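
Taken together, a rough end-to-end usage sketch of the new binding might look like the following. Only the binding name and its keyword arguments (`loader`, `data_path`, `enable_etdump`, `debug_buffer_size`) come from this diff; the loader construction is hypothetical, the Python-side surface of `PyDataLoader` lives in `executorch.extension.pybindings.data_loader` (imported above so the type is registered once and shared), and the module name `portable_lib` is an assumption about how this extension is built.

```python
# Hypothetical usage sketch; only the binding name and its keyword arguments
# are taken from this diff.
import executorch.extension.pybindings.data_loader  # registers the PyDataLoader type
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_data_loader,
)

loader = make_my_loader()  # hypothetical helper returning a PyDataLoader

module = _load_for_executorch_from_data_loader(
    loader,
    data_path=None,        # or a path to an external data (.ptd) file
    enable_etdump=False,
    debug_buffer_size=0,
)

# example_input is a placeholder for whatever the method expects.
outputs = module.run_method("forward", (example_input,))
```
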