2323# SOFTWARE.
2424#
2525import csv
26+ import math
2627import os
2728import secrets
2829import time
3536
3637
3738def setup (
38- storage_backends , block_size , device_id , io_size , transferStreamNumber
39+ storage_backends ,
40+ block_size ,
41+ device_id ,
42+ io_size ,
43+ transferStreamNumber ,
44+ transferIoDirect ,
3945) -> UcmKVStoreBase :
4046 config = {
4147 "storage_backends" : storage_backends ,
@@ -44,19 +50,41 @@ def setup(
4450 "device" : device_id ,
4551 "io_size" : io_size ,
4652 "transferStreamNumber" : transferStreamNumber ,
53+ "transferIoDirect" : transferIoDirect ,
4754 }
4855 return UcmNfsStore (config )
4956
5057
def make_aligned_tensor(shape, dtype, device, alignment=4096):
    """Allocate a tensor of ``shape``/``dtype`` on ``device`` whose data
    pointer is aligned to ``alignment`` bytes.

    A byte buffer over-allocated by one extra ``alignment`` span is created,
    and the returned tensor is a view starting at the first aligned address
    inside it.  The view keeps a reference to the backing buffer so the
    allocation is not freed while the tensor is in use.

    Args:
        shape: target tensor shape (any sequence of ints).
        dtype: target torch dtype.
        device: torch device string or object (e.g. ``"cuda:0"``).
        alignment: required byte alignment of the data pointer; default 4096
            (typical direct-I/O sector alignment).

    Returns:
        A tensor of the requested shape/dtype whose ``data_ptr()`` is a
        multiple of ``alignment``.
    """
    numel = math.prod(shape)
    elem_size = torch.empty((), dtype=dtype).element_size()
    payload_bytes = numel * elem_size

    # Over-allocate by one alignment span so an aligned start address always
    # exists inside the buffer.  Allocate directly on the target device
    # instead of the deprecated ByteTensor-on-CPU + copy.
    raw = torch.empty(payload_bytes + alignment, dtype=torch.uint8, device=device)

    ptr = raw.data_ptr()
    # Bytes to skip from `ptr` to reach the next alignment boundary
    # (0 when already aligned).
    skip = (-ptr) % alignment

    # Slice exactly the payload span before reinterpreting: view(dtype)
    # requires the byte count to be a multiple of the element size, and the
    # padded tail would generally violate that.
    aligned_bytes = raw[skip : skip + payload_bytes]
    tensor = aligned_bytes.view(dtype).view(shape)
    # Keep the raw buffer alive for as long as the aligned view is used.
    tensor.storage_ref = raw
    return tensor
77+
78+
def make_buffers(
    block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
):
    """Build the benchmark fixtures: random block hashes plus per-layer
    page-aligned KV-cache tensors on the target CUDA device.

    Returns:
        (hashes, kv_caches) where ``hashes`` is a list of ``block_number``
        random hex tokens and ``kv_caches`` maps each layer index in
        ``range(block_layer)`` to a float16 tensor of shape
        ``[kv, block_number, block_len, num_head, head_dim]``.
    """
    hashes = [secrets.token_hex(16) for _ in range(block_number)]
    cache_shape = [kv, block_number, block_len, num_head, head_dim]
    target_device = f"cuda:{device_id}"
    kv_caches = {
        layer: make_aligned_tensor(
            cache_shape,
            dtype=torch.float16,
            device=target_device,
        )
        for layer in range(block_layer)
    }
    return hashes, kv_caches
@@ -69,6 +97,14 @@ def store_all_hashes(hashes: List[str]):
6997 f .write (h + "\n " )
7098
7199
def load_hashes_from_file() -> List[str]:
    """Load previously persisted block hashes, one hex hash per line.

    Reads ``kvcache_block_hashes.txt`` from this module's directory
    (presumably the file written by ``store_all_hashes`` — the same name and
    one-hash-per-line format).

    Returns:
        The list of hashes, or an empty list when the file does not exist.
        Blank lines are skipped so a trailing newline or stray empty line
        never yields an empty-string "hash".
    """
    file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
    if not os.path.exists(file_path):
        return []
    with open(file_path, "r", encoding="utf-8") as f:
        # Iterate the file object directly (no readlines() materialization)
        # and drop empty lines.
        return [stripped for line in f if (stripped := line.strip())]
107+
72108def embed (
73109 store : UcmKVStoreBase ,
74110 hashes : List [str ],
@@ -177,6 +213,8 @@ def run(
177213 block_elem_size : int ,
178214 kv : int ,
179215 mla : bool ,
216+ transferIoDirect : bool ,
217+ operation_mode : str = "both" , # "write_only", "read_only", or "both"
180218) -> Tuple [float , float , float , float , float , float ]:
181219 """
182220 Run a single test with given parameters and return performance metrics.
@@ -196,87 +234,99 @@ def run(
196234 w_size_sum , r_size_sum = 0.0 , 0.0
197235
198236 store = setup (
199- storage_backends , block_size , device_id , io_size , transferStreamNumber
237+ storage_backends ,
238+ block_size ,
239+ device_id ,
240+ io_size ,
241+ transferStreamNumber ,
242+ transferIoDirect ,
200243 )
244+
201245 for r in range (repeat ):
202246 print (f"\n --- Round { r + 1 } ---" )
203247
204- hashes , kvcaches = make_buffers (
205- real_blocks ,
206- device_id ,
207- batch_size ,
208- head_size ,
209- block_len ,
210- block_layer ,
211- num_head ,
212- kv ,
213- )
214-
215- results = store .create (hashes [:batch_size ])
216- assert sum (results ) == 0 , "Create operation failed"
217-
218- w_size , w_time , w_bw = embed (
219- store ,
220- hashes [:batch_size ],
221- kvcaches ,
222- mla ,
223- )
224- store .commit (hashes [:batch_size ], True )
225-
226- store_all_hashes (hashes [:batch_size ])
227-
228- r_size , r_time , r_bw = fetch (
229- store ,
230- hashes [:batch_size ],
231- kvcaches ,
232- mla ,
233- )
234-
235- w_bw_list .append (w_bw )
236- r_bw_list .append (r_bw )
237- w_time_list .append (w_time )
238- r_time_list .append (r_time )
239- w_size_sum += w_size
240- r_size_sum += r_size
241-
242- # Clean up resources
243- del kvcaches , hashes
244- torch .cuda .empty_cache ()
248+ if operation_mode in ["write_only" , "both" ]:
249+ hashes , kvcaches = make_buffers (
250+ real_blocks ,
251+ device_id ,
252+ batch_size ,
253+ head_size ,
254+ block_len ,
255+ block_layer ,
256+ num_head ,
257+ kv ,
258+ )
259+
260+ results = store .create (hashes [:batch_size ])
261+ assert sum (results ) == 0 , "Create operation failed"
262+
263+ w_size , w_time , w_bw = embed (
264+ store ,
265+ hashes [:batch_size ],
266+ kvcaches ,
267+ mla ,
268+ )
269+ store .commit (hashes [:batch_size ], True )
270+
271+ if r == 0 :
272+ store_all_hashes (hashes [:batch_size ])
273+
274+ w_bw_list .append (w_bw )
275+ w_time_list .append (w_time )
276+ w_size_sum += w_size
277+
278+ if operation_mode == "write_only" :
279+ del kvcaches , hashes
280+ torch .cuda .empty_cache ()
281+
282+ if operation_mode in ["read_only" , "both" ]:
283+ if operation_mode == "read_only" :
284+ saved_hashes = load_hashes_from_file ()
285+ if not saved_hashes :
286+ raise RuntimeError ("No saved hashes found for read operation" )
287+
288+ _ , kvcaches = make_buffers (
289+ real_blocks ,
290+ device_id ,
291+ batch_size ,
292+ head_size ,
293+ block_len ,
294+ block_layer ,
295+ num_head ,
296+ kv ,
297+ )
298+
299+ r_size , r_time , r_bw = fetch (
300+ store ,
301+ saved_hashes [:batch_size ],
302+ kvcaches ,
303+ mla ,
304+ )
305+ else :
306+ r_size , r_time , r_bw = fetch (
307+ store ,
308+ hashes [:batch_size ],
309+ kvcaches ,
310+ mla ,
311+ )
312+
313+ r_bw_list .append (r_bw )
314+ r_time_list .append (r_time )
315+ r_size_sum += r_size
316+
317+ if operation_mode == "read_only" :
318+ del kvcaches
319+ torch .cuda .empty_cache ()
320+ else :
321+ del kvcaches , hashes
322+ torch .cuda .empty_cache ()
245323
246324 del store
247- avg_w_bw = sum (w_bw_list ) / repeat
248- avg_r_bw = sum (r_bw_list ) / repeat
249- avg_w_time = sum (w_time_list ) / repeat
250- avg_r_time = sum (r_time_list ) / repeat
251- avg_w_size = w_size_sum / (1024 ** 3 ) / repeat
252- avg_r_size = r_size_sum / (1024 ** 3 ) / repeat
325+ avg_w_bw = sum (w_bw_list ) / len ( w_bw_list ) if w_bw_list else 0.0
326+ avg_r_bw = sum (r_bw_list ) / len ( r_bw_list ) if r_bw_list else 0.0
327+ avg_w_time = sum (w_time_list ) / len ( w_time_list ) if w_time_list else 0.0
328+ avg_r_time = sum (r_time_list ) / len ( r_time_list ) if r_time_list else 0.0
329+ avg_w_size = w_size_sum / (1024 ** 3 ) / len ( w_time_list ) if w_time_list else 0.0
330+ avg_r_size = r_size_sum / (1024 ** 3 ) / len ( r_time_list ) if r_time_list else 0.0
253331
254332 return avg_w_size , avg_w_time , avg_w_bw , avg_r_time , avg_r_bw , avg_r_size
255-
256-
257- if __name__ == "__main__" :
258- os .environ ["UC_LOGGER_LEVEL" ] = "debug"
259-
260- try :
261- result = run (
262- storage_backends = "/home/nfs/zht_data" ,
263- device_id = 1 ,
264- repeat = 1 ,
265- num_head = 1 ,
266- block_len = 128 ,
267- transferStreamNumber = 32 ,
268- num_tokens = 4096 ,
269- block_layer = 61 ,
270- head_size = 576 ,
271- block_elem_size = 2 ,
272- kv = 1 ,
273- mla = True ,
274- )
275-
276- avg_w_size , avg_w_time , avg_w_bw , avg_r_time , avg_r_bw , avg_r_size = result
277-
278- except Exception as e :
279- print (f"Error: { e } " )
280- import traceback
281-
282- traceback .print_exc ()
0 commit comments