Skip to content

Commit ab09a45

Browse files
authored
bugfix: remove folly::MemoryMapping to optimize model loading performance.
1 parent 4533086 commit ab09a45

File tree

2 files changed

+68
-15
lines changed

2 files changed

+68
-15
lines changed

xllm/core/framework/state_dict/state_dict.cpp

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ limitations under the License.
1919
#include <ATen/core/TensorBody.h>
2020
#include <absl/strings/match.h>
2121
#include <caffe2/serialize/inline_container.h>
22+
#include <fcntl.h>
2223
#include <glog/logging.h>
24+
#include <sys/mman.h>
25+
#include <sys/stat.h>
2326
#include <torch/csrc/jit/serialization/import_read.h>
2427
#include <torch/csrc/jit/serialization/storage_context.h>
2528
#include <torch/torch.h>
29+
#include <unistd.h>
2630

2731
#include <memory>
2832

@@ -79,6 +83,40 @@ std::vector<int64_t> get_sizes(const View* view) {
7983
return sizes;
8084
}
8185

86+
std::unique_ptr<MemoryMapping> create_memory_mapping(const char* weights_file) {
87+
auto mapping = std::make_unique<MemoryMapping>();
88+
89+
mapping->fd = open(weights_file, O_RDONLY);
90+
if (mapping->fd == -1) {
91+
LOG(FATAL) << "Failed to open weight file: " << weights_file;
92+
}
93+
94+
struct stat sb;
95+
if (fstat(mapping->fd, &sb) == -1) {
96+
LOG(FATAL) << "Failed to get file size for weight file: " << weights_file;
97+
}
98+
mapping->mapped_size = sb.st_size;
99+
100+
mapping->mapped_addr =
101+
mmap(NULL, mapping->mapped_size, PROT_READ, MAP_PRIVATE, mapping->fd, 0);
102+
if (mapping->mapped_addr == MAP_FAILED) {
103+
LOG(FATAL) << "Failed to map file: " << weights_file;
104+
}
105+
106+
return mapping;
107+
}
108+
109+
void destroy_memory_mapping(MemoryMapping* mapping) {
110+
if (mapping) {
111+
if (mapping->mapped_addr != MAP_FAILED) {
112+
munmap(mapping->mapped_addr, mapping->mapped_size);
113+
}
114+
if (mapping->fd != -1) {
115+
close(mapping->fd);
116+
}
117+
free(mapping);
118+
}
119+
}
82120
} // namespace
83121

84122
StateDict::StateDict(std::unordered_map<std::string, torch::Tensor> dict,
@@ -145,25 +183,31 @@ StateDict StateDict::get_dict_with_prefix(
145183
}
146184

147185
StateDictFromSafeTensor::StateDictFromSafeTensor(
148-
std::unique_ptr<folly::MemoryMapping> mem_map,
186+
std::unique_ptr<MemoryMapping> mem_map,
149187
std::unordered_map<std::string, torch::Tensor> dict)
150188
: StateDict(std::move(dict)), mem_map_(std::move(mem_map)) {}
151189

190+
StateDictFromSafeTensor::~StateDictFromSafeTensor() {
191+
destroy_memory_mapping(mem_map_.release());
192+
}
193+
152194
std::unique_ptr<StateDict> StateDictFromSafeTensor::load(
153195
const std::string& weights_file) {
154-
folly::MemoryMapping::Options options;
155-
options.setPrefault(true).setReadable(true);
156-
auto mem_map = std::make_unique<folly::MemoryMapping>(weights_file.c_str(),
157-
0, // offset
158-
-1, // length
159-
options);
196+
std::unique_ptr<MemoryMapping> mem_map = std::unique_ptr<MemoryMapping>(
197+
create_memory_mapping(weights_file.c_str()));
198+
199+
if (!mem_map) {
200+
LOG(FATAL) << "Failed to create memory mapping for " << weights_file;
201+
}
202+
160203
if (util::get_bool_env(ENV_MLOCK_ENABLED, DEFAULT_MLOCK_ENABLED)) {
161-
mem_map->mlock(folly::MemoryMapping::LockMode::MUST_LOCK);
204+
if (mlock(mem_map->mapped_addr, mem_map->mapped_size) == -1) {
205+
LOG(FATAL) << "Failed to lock memory for file: " << weights_file;
206+
}
162207
}
163-
const folly::ByteRange content = mem_map->range();
164-
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
165-
const uint8_t* data = reinterpret_cast<const uint8_t*>(content.data());
166-
const size_t size = content.size();
208+
209+
const uint8_t* data = static_cast<const uint8_t*>(mem_map->mapped_addr);
210+
const size_t size = mem_map->mapped_size;
167211

168212
std::unordered_map<std::string, torch::Tensor> dict;
169213
// safetensors

xllm/core/framework/state_dict/state_dict.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616

1717
#pragma once
1818
#include <c10/core/DeviceType.h>
19-
#include <folly/system/MemoryMapping.h>
2019
#include <torch/torch.h>
2120

2221
#include <memory>
@@ -66,16 +65,26 @@ class StateDict {
6665
std::string prefix_;
6766
};
6867

68+
struct MemoryMapping {
69+
void* mapped_addr = nullptr;
70+
size_t mapped_size = 0;
71+
int fd = -1;
72+
73+
MemoryMapping() = default;
74+
};
75+
6976
class StateDictFromSafeTensor : public StateDict {
7077
public:
71-
StateDictFromSafeTensor(std::unique_ptr<folly::MemoryMapping> mem_map,
78+
StateDictFromSafeTensor(std::unique_ptr<MemoryMapping> mem_map,
7279
std::unordered_map<std::string, torch::Tensor> dict);
7380

81+
~StateDictFromSafeTensor();
82+
7483
static std::unique_ptr<StateDict> load(const std::string& weights_file);
7584

7685
private:
7786
// memory mapping for safetensors
78-
std::unique_ptr<folly::MemoryMapping> mem_map_;
87+
std::unique_ptr<MemoryMapping> mem_map_;
7988
};
8089

8190
} // namespace xllm

0 commit comments

Comments
 (0)