@@ -19,10 +19,14 @@ limitations under the License.
1919#include < ATen/core/TensorBody.h>
2020#include < absl/strings/match.h>
2121#include < caffe2/serialize/inline_container.h>
22+ #include < fcntl.h>
2223#include < glog/logging.h>
24+ #include < sys/mman.h>
25+ #include < sys/stat.h>
2326#include < torch/csrc/jit/serialization/import_read.h>
2427#include < torch/csrc/jit/serialization/storage_context.h>
2528#include < torch/torch.h>
29+ #include < unistd.h>
2630
2731#include < memory>
2832
@@ -79,6 +83,40 @@ std::vector<int64_t> get_sizes(const View* view) {
7983 return sizes;
8084}
8185
86+ std::unique_ptr<MemoryMapping> create_memory_mapping (const char * weights_file) {
87+ auto mapping = std::make_unique<MemoryMapping>();
88+
89+ mapping->fd = open (weights_file, O_RDONLY);
90+ if (mapping->fd == -1 ) {
91+ LOG (FATAL) << " Failed to open weight file: " << weights_file;
92+ }
93+
94+ struct stat sb;
95+ if (fstat (mapping->fd , &sb) == -1 ) {
96+ LOG (FATAL) << " Failed to get file size for weight file: " << weights_file;
97+ }
98+ mapping->mapped_size = sb.st_size ;
99+
100+ mapping->mapped_addr =
101+ mmap (NULL , mapping->mapped_size , PROT_READ, MAP_PRIVATE, mapping->fd , 0 );
102+ if (mapping->mapped_addr == MAP_FAILED) {
103+ LOG (FATAL) << " Failed to map file: " << weights_file;
104+ }
105+
106+ return mapping;
107+ }
108+
109+ void destroy_memory_mapping (MemoryMapping* mapping) {
110+ if (mapping) {
111+ if (mapping->mapped_addr != MAP_FAILED) {
112+ munmap (mapping->mapped_addr , mapping->mapped_size );
113+ }
114+ if (mapping->fd != -1 ) {
115+ close (mapping->fd );
116+ }
117+ free (mapping);
118+ }
119+ }
82120} // namespace
83121
84122StateDict::StateDict (std::unordered_map<std::string, torch::Tensor> dict,
@@ -145,25 +183,31 @@ StateDict StateDict::get_dict_with_prefix(
145183}
146184
// Constructs the state dict from `dict` (moved into the StateDict base) and
// takes ownership of `mem_map` by moving it into mem_map_.
// In this diff the parameter type changes from folly::MemoryMapping to the
// file's own MemoryMapping.
// NOTE(review): presumably the mapping is kept so that tensors in `dict` that
// point into the mapped file stay valid - confirm against load().
147185StateDictFromSafeTensor::StateDictFromSafeTensor (
148- std::unique_ptr<folly:: MemoryMapping> mem_map,
186+ std::unique_ptr<MemoryMapping> mem_map,
149187 std::unordered_map<std::string, torch::Tensor> dict)
150188 : StateDict(std::move(dict)), mem_map_(std::move(mem_map)) {}
151189
190+ StateDictFromSafeTensor::~StateDictFromSafeTensor () {
191+ destroy_memory_mapping (mem_map_.release ());
192+ }
193+
152194std::unique_ptr<StateDict> StateDictFromSafeTensor::load (
153195 const std::string& weights_file) {
154- folly::MemoryMapping::Options options;
155- options.setPrefault (true ).setReadable (true );
156- auto mem_map = std::make_unique<folly::MemoryMapping>(weights_file.c_str (),
157- 0 , // offset
158- -1 , // length
159- options);
196+ std::unique_ptr<MemoryMapping> mem_map = std::unique_ptr<MemoryMapping>(
197+ create_memory_mapping (weights_file.c_str ()));
198+
199+ if (!mem_map) {
200+ LOG (FATAL) << " Failed to create memory mapping for " << weights_file;
201+ }
202+
160203 if (util::get_bool_env (ENV_MLOCK_ENABLED, DEFAULT_MLOCK_ENABLED)) {
161- mem_map->mlock (folly::MemoryMapping::LockMode::MUST_LOCK);
204+ if (mlock (mem_map->mapped_addr , mem_map->mapped_size ) == -1 ) {
205+ LOG (FATAL) << " Failed to lock memory for file: " << weights_file;
206+ }
162207 }
163- const folly::ByteRange content = mem_map->range ();
164- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
165- const uint8_t * data = reinterpret_cast <const uint8_t *>(content.data ());
166- const size_t size = content.size ();
208+
209+ const uint8_t * data = static_cast <const uint8_t *>(mem_map->mapped_addr );
210+ const size_t size = mem_map->mapped_size ;
167211
168212 std::unordered_map<std::string, torch::Tensor> dict;
169213 // safetensors
0 commit comments