
Commit 28f6f35

[Fix] Added the 'transferIoDirect' option (#352)
1 parent 00b9d56 commit 28f6f35
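In short, the commit threads a new boolean flag, transferIoDirect, from the Python connector config through the pybind11 Config binding and into TransManager/PosixQueue, where it is forwarded to File::Write/File::Read so NFS transfers can use direct (page-cache-bypassing) I/O. A minimal sketch of how a caller might opt in, based on the connector and e2e test diffs below; the import path and backend mount point are assumptions, and other required config keys are omitted:

from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore  # module path assumed from this repo's tree

config = {
    "storage_backends": "/mnt/nfs/kvcache",  # placeholder mount point
    "device": 0,
    "io_size": 262144,
    "transferStreamNumber": 32,
    "transferIoDirect": True,  # new in this commit; falls back to False when omitted
}
store = UcmNfsStore(config)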

File tree

9 files changed: +180 −90 lines changed


ucm/store/nfsstore/cc/api/nfsstore.cc

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,7 @@ class NFSStoreImpl : public NFSStore {
         status =
             this->transMgr_.Setup(config.transferDeviceId, config.transferStreamNumber,
                                   config.transferIoSize, config.transferBufferNumber,
-                                  this->spaceMgr_.GetSpaceLayout(), config.transferTimeoutMs);
+                                  this->spaceMgr_.GetSpaceLayout(), config.transferTimeoutMs, config.transferIoDirect);
         if (status.Failure()) {
             UC_ERROR("Failed({}) to setup TsfTaskManager.", status);
             return status.Underlying();
@@ -124,6 +124,7 @@ class NFSStoreImpl : public NFSStore {
         UC_INFO("Set UC::storageCapacity to {}.", config.storageCapacity);
         UC_INFO("Set UC::RecycleEnable to {}.", config.recycleEnable);
         UC_INFO("Set UC::RecycleThreshold to {}.", config.recycleThresholdRatio);
+        UC_INFO("Set UC::IoDirect to {}.", config.transferIoDirect);
     }

 private:

ucm/store/nfsstore/cc/api/nfsstore.h

Lines changed: 3 additions & 1 deletion
@@ -46,14 +46,16 @@ class NFSStore : public CCStore {
         size_t storageCapacity;
         bool recycleEnable;
         float recycleThresholdRatio;
+        bool transferIoDirect;

         Config(const std::vector<std::string>& storageBackends, const size_t kvcacheBlockSize,
                const bool transferEnable)
             : storageBackends{storageBackends}, kvcacheBlockSize{kvcacheBlockSize},
               transferEnable{transferEnable}, transferDeviceId{-1}, transferStreamNumber{32},
               transferIoSize{262144}, transferBufferNumber{512}, transferTimeoutMs{30000},
               tempDumpDirEnable{false}, hotnessEnable{true}, hotnessInterval{60},
-              storageCapacity{0}, recycleEnable{true}, recycleThresholdRatio{0.7f}
+              storageCapacity{0}, recycleEnable{true}, recycleThresholdRatio{0.7f},
+              transferIoDirect{false}
         {
         }
     };

ucm/store/nfsstore/cc/domain/trans/posix_queue.cc

Lines changed: 4 additions & 3 deletions
@@ -35,13 +35,14 @@ bool IsAligned(const T value)
 }

 Status PosixQueue::Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber,
-                         TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs)
+                         TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect)
 {
     this->deviceId_ = deviceId;
     this->bufferSize_ = bufferSize;
     this->bufferNumber_ = bufferNumber;
     this->failureSet_ = failureSet;
     this->layout_ = layout;
+    this->useDirect_ = useDirect;
     auto success =
         this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); })
             .SetWorkerFn([this](auto& shard, const auto& device) { this->Work(shard, device); })
@@ -106,7 +107,7 @@ Status PosixQueue::D2S(Task::Shard& shard, const Device& device)
     auto status = device->D2HSync((std::byte*)hub, (std::byte*)shard.address, shard.length);
     if (status.Failure()) { return status; }
     auto path = this->layout_->DataFilePath(shard.block, true);
-    return File::Write(path, shard.offset, shard.length, (uintptr_t)hub);
+    return File::Write(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_);
 }

 Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
@@ -118,7 +119,7 @@ Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
     }
     auto hub = shard.buffer.get();
     auto path = this->layout_->DataFilePath(shard.block, false);
-    auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub);
+    auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_);
     if (status.Failure()) { return status; }
     return device->H2DAsync((std::byte*)shard.address, (std::byte*)hub, shard.length);
 }
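File::Write and File::Read now receive the useDirect_ flag; presumably the data file is opened with O_DIRECT when it is set, which is why buffers, offsets, and lengths must be block-aligned (cf. IsAligned above and make_aligned_tensor in the e2e test below). A minimal Python sketch of that constraint, not the store's actual File helper; the 4096-byte alignment and target path are assumptions:

import mmap
import os

ALIGN = 4096  # assumed logical block size; O_DIRECT requires aligned buffer, offset, and length

def write_direct(path: str, payload: bytes) -> None:
    # Round the length up to the alignment; O_DIRECT rejects ragged sizes.
    padded = len(payload) + (-len(payload)) % ALIGN
    buf = mmap.mmap(-1, padded)  # anonymous mmap gives a page-aligned buffer
    buf[: len(payload)] = payload
    # The path must be on a filesystem that supports O_DIRECT (e.g. ext4 or NFS, not tmpfs).
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_DIRECT, 0o644)
    try:
        os.write(fd, buf)  # the kernel transfers straight from the aligned buffer, bypassing the page cache
    finally:
        os.close(fd)
        buf.close()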

ucm/store/nfsstore/cc/domain/trans/posix_queue.h

Lines changed: 2 additions & 1 deletion
@@ -40,11 +40,12 @@ class PosixQueue : public TaskQueue {
     size_t bufferNumber_{0};
     TaskSet* failureSet_{nullptr};
     const SpaceLayout* layout_{nullptr};
+    bool useDirect_{false};
     ThreadPool<Task::Shard, Device> backend_{};

 public:
     Status Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber,
-                 TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs);
+                 TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false);
     void Push(std::list<Task::Shard>& shards) noexcept override;

 private:

ucm/store/nfsstore/cc/domain/trans/trans_manager.h

Lines changed: 2 additions & 2 deletions
@@ -32,14 +32,14 @@ namespace UC {
 class TransManager : public TaskManager {
 public:
     Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t ioSize,
-                 const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs)
+                 const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false)
     {
         this->timeoutMs_ = timeoutMs;
         auto status = Status::OK();
         for (size_t i = 0; i < streamNumber; i++) {
             auto q = std::make_shared<PosixQueue>();
             status =
-                q->Setup(deviceId, ioSize, bufferNumber, &this->failureSet_, layout, timeoutMs);
+                q->Setup(deviceId, ioSize, bufferNumber, &this->failureSet_, layout, timeoutMs, useDirect);
             if (status.Failure()) { break; }
             this->queues_.emplace_back(std::move(q));
         }

ucm/store/nfsstore/cpy/nfsstore.py.cc

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ PYBIND11_MODULE(ucmnfsstore, module)
     config.def_readwrite("transferDeviceId", &UC::NFSStorePy::Config::transferDeviceId);
     config.def_readwrite("transferStreamNumber", &UC::NFSStorePy::Config::transferStreamNumber);
     config.def_readwrite("transferIoSize", &UC::NFSStorePy::Config::transferIoSize);
+    config.def_readwrite("transferIoDirect", &UC::NFSStorePy::Config::transferIoDirect);
     config.def_readwrite("transferBufferNumber", &UC::NFSStorePy::Config::transferBufferNumber);
     config.def_readwrite("transferTimeoutMs", &UC::NFSStorePy::Config::transferTimeoutMs);
     config.def_readwrite("tempDumpDirEnable", &UC::NFSStorePy::Config::tempDumpDirEnable);

ucm/store/nfsstore/nfsstore_connector.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ def __init__(self, config: Dict):
         if transfer_enable:
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
+            param.transferIoDirect = config.get("transferIoDirect", False)

             # NOTE: compatible with legacy nfsstore lib
             if hasattr(param, "storageCapacity"):

ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 129 additions & 79 deletions
@@ -23,6 +23,7 @@
 # SOFTWARE.
 #
 import csv
+import math
 import os
 import secrets
 import time
@@ -35,7 +36,12 @@


 def setup(
-    storage_backends, block_size, device_id, io_size, transferStreamNumber
+    storage_backends,
+    block_size,
+    device_id,
+    io_size,
+    transferStreamNumber,
+    transferIoDirect,
 ) -> UcmKVStoreBase:
     config = {
         "storage_backends": storage_backends,
@@ -44,19 +50,41 @@ def setup(
         "device": device_id,
         "io_size": io_size,
         "transferStreamNumber": transferStreamNumber,
+        "transferIoDirect": transferIoDirect,
     }
     return UcmNfsStore(config)


+def make_aligned_tensor(shape, dtype, device, alignment=4096):
+    numl = math.prod(shape)
+    dtype_size = torch.tensor(1, dtype=dtype).element_size()
+    total_byters = numl * dtype_size
+
+    padded_bytes = total_byters + alignment
+    storage = torch.ByteTensor(padded_bytes).to(device)
+
+    ptr = storage.data_ptr()
+    offset = ptr % alignment
+    if offset != 0:
+        aligned_ptr = ptr + (alignment - offset)
+    else:
+        aligned_ptr = ptr
+
+    aligned_storage = storage[(aligned_ptr - ptr) :].view(dtype)
+    tensor = aligned_storage[:numl].view(shape)
+    tensor.storage_ref = storage
+    return tensor
+
+
 def make_buffers(
     block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
 ):
     hashes = [secrets.token_hex(16) for _ in range(block_number)]
     kv_caches = {}
     for i in range(block_layer):
-        kv_caches[i] = torch.rand(
+        kv_caches[i] = make_aligned_tensor(
             [kv, block_number, block_len, num_head, head_dim],
-            dtype=torch.bfloat16,
+            dtype=torch.float16,
             device=f"cuda:{device_id}",
         )
     return hashes, kv_caches
@@ -69,6 +97,14 @@ def store_all_hashes(hashes: List[str]):
             f.write(h + "\n")


+def load_hashes_from_file() -> List[str]:
+    file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
+    if not os.path.exists(file_path):
+        return []
+    with open(file_path, "r", encoding="utf-8") as f:
+        return [line.strip() for line in f.readlines()]
+
+
 def embed(
     store: UcmKVStoreBase,
     hashes: List[str],
@@ -177,6 +213,8 @@ def run(
     block_elem_size: int,
     kv: int,
     mla: bool,
+    transferIoDirect: bool,
+    operation_mode: str = "both",  # "write_only", "read_only", or "both"
 ) -> Tuple[float, float, float, float, float, float]:
     """
     Run a single test with given parameters and return performance metrics.
@@ -196,87 +234,99 @@ def run(
     w_size_sum, r_size_sum = 0.0, 0.0

     store = setup(
-        storage_backends, block_size, device_id, io_size, transferStreamNumber
+        storage_backends,
+        block_size,
+        device_id,
+        io_size,
+        transferStreamNumber,
+        transferIoDirect,
     )
+
     for r in range(repeat):
         print(f"\n--- Round {r+1} ---")

-        hashes, kvcaches = make_buffers(
-            real_blocks,
-            device_id,
-            batch_size,
-            head_size,
-            block_len,
-            block_layer,
-            num_head,
-            kv,
-        )
-
-        results = store.create(hashes[:batch_size])
-        assert sum(results) == 0, "Create operation failed"
-
-        w_size, w_time, w_bw = embed(
-            store,
-            hashes[:batch_size],
-            kvcaches,
-            mla,
-        )
-        store.commit(hashes[:batch_size], True)
-
-        store_all_hashes(hashes[:batch_size])
-
-        r_size, r_time, r_bw = fetch(
-            store,
-            hashes[:batch_size],
-            kvcaches,
-            mla,
-        )
-
-        w_bw_list.append(w_bw)
-        r_bw_list.append(r_bw)
-        w_time_list.append(w_time)
-        r_time_list.append(r_time)
-        w_size_sum += w_size
-        r_size_sum += r_size
-
-        # Clean up resources
-        del kvcaches, hashes
-        torch.cuda.empty_cache()
+        if operation_mode in ["write_only", "both"]:
+            hashes, kvcaches = make_buffers(
+                real_blocks,
+                device_id,
+                batch_size,
+                head_size,
+                block_len,
+                block_layer,
+                num_head,
+                kv,
+            )
+
+            results = store.create(hashes[:batch_size])
+            assert sum(results) == 0, "Create operation failed"
+
+            w_size, w_time, w_bw = embed(
+                store,
+                hashes[:batch_size],
+                kvcaches,
+                mla,
+            )
+            store.commit(hashes[:batch_size], True)
+
+            if r == 0:
+                store_all_hashes(hashes[:batch_size])
+
+            w_bw_list.append(w_bw)
+            w_time_list.append(w_time)
+            w_size_sum += w_size
+
+            if operation_mode == "write_only":
+                del kvcaches, hashes
+                torch.cuda.empty_cache()
+
+        if operation_mode in ["read_only", "both"]:
+            if operation_mode == "read_only":
+                saved_hashes = load_hashes_from_file()
+                if not saved_hashes:
+                    raise RuntimeError("No saved hashes found for read operation")
+
+                _, kvcaches = make_buffers(
+                    real_blocks,
+                    device_id,
+                    batch_size,
+                    head_size,
+                    block_len,
+                    block_layer,
+                    num_head,
+                    kv,
+                )
+
+                r_size, r_time, r_bw = fetch(
+                    store,
+                    saved_hashes[:batch_size],
+                    kvcaches,
+                    mla,
+                )
+            else:
+                r_size, r_time, r_bw = fetch(
+                    store,
+                    hashes[:batch_size],
+                    kvcaches,
+                    mla,
+                )
+
+            r_bw_list.append(r_bw)
+            r_time_list.append(r_time)
+            r_size_sum += r_size
+
+            if operation_mode == "read_only":
+                del kvcaches
+                torch.cuda.empty_cache()
+            else:
+                del kvcaches, hashes
+                torch.cuda.empty_cache()

     del store
-    avg_w_bw = sum(w_bw_list) / repeat
-    avg_r_bw = sum(r_bw_list) / repeat
-    avg_w_time = sum(w_time_list) / repeat
-    avg_r_time = sum(r_time_list) / repeat
-    avg_w_size = w_size_sum / (1024**3) / repeat
-    avg_r_size = r_size_sum / (1024**3) / repeat
+    avg_w_bw = sum(w_bw_list) / len(w_bw_list) if w_bw_list else 0.0
+    avg_r_bw = sum(r_bw_list) / len(r_bw_list) if r_bw_list else 0.0
+    avg_w_time = sum(w_time_list) / len(w_time_list) if w_time_list else 0.0
+    avg_r_time = sum(r_time_list) / len(r_time_list) if r_time_list else 0.0
+    avg_w_size = w_size_sum / (1024**3) / len(w_time_list) if w_time_list else 0.0
+    avg_r_size = r_size_sum / (1024**3) / len(r_time_list) if r_time_list else 0.0

     return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size
-
-
-if __name__ == "__main__":
-    os.environ["UC_LOGGER_LEVEL"] = "debug"
-
-    try:
-        result = run(
-            storage_backends="/home/nfs/zht_data",
-            device_id=1,
-            repeat=1,
-            num_head=1,
-            block_len=128,
-            transferStreamNumber=32,
-            num_tokens=4096,
-            block_layer=61,
-            head_size=576,
-            block_elem_size=2,
-            kv=1,
-            mla=True,
-        )
-
-        avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result
-
-    except Exception as e:
-        print(f"Error: {e}")
-        import traceback
-
-        traceback.print_exc()
