@@ -86,13 +86,15 @@ def generate(
     if extra_kwargs is not None:
         kwargs.update(extra_kwargs)

+    is_fp8 = "fp8" in kwargs["attn_name"]
     if isinstance(input_ids, torch.Tensor):
         if len(input_ids.shape) == 1:
             input_ids = input_ids.unsqueeze(0)

         is_batch = input_ids.shape[0] > 1
-        # our model requires batch dimension
-        if not is_batch:
+        # our model requires batch dimension when running with fp8
+        # this is fixed in torch >= 2.8
+        if is_fp8 and not is_batch:
             input_ids, kwargs = adjust_inputs_to_batch(input_ids, **kwargs)
     else:
         raise TypeError("input_ids must be one of Tensor or List")
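A minimal sketch (not part of the diff) of what the fp8 path above implies: a single unbatched prompt is mimicked as batch size 2 before generation. `adjust_inputs_to_batch` is not shown in this commit, so the row duplication below is an assumption for illustration only.

import torch

input_ids = torch.tensor([1, 2, 3, 4])  # a single unbatched prompt
input_ids = input_ids.unsqueeze(0)      # shape (1, 4): batch dimension added
is_fp8 = True                           # as if "fp8" were in kwargs["attn_name"]
is_batch = input_ids.shape[0] > 1
if is_fp8 and not is_batch:
    # hypothetical stand-in for adjust_inputs_to_batch: duplicate the row to batch size 2
    input_ids = input_ids.expand(2, -1).contiguous()
print(input_ids.shape)                  # torch.Size([2, 4])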
@@ -115,7 +117,10 @@ def generate(
     # if we set these variables here, we run the risk of warming up and generating with different sizes
     _MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
     _MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
-    NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
+    # if the user provides a hint for the number of blocks to use, use it directly
+    NUM_BLOCKS = kwargs.get(
+        "_kvcache_num_blocks_hint", (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
+    )

     if hasattr(model, "head"):
         model_dtype = model.head.weight.dtype
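A quick sketch (not part of the diff) of how NUM_BLOCKS now resolves: the optional `_kvcache_num_blocks_hint` kwarg wins, otherwise the value falls back to the old formula. BLOCK_SIZE and the environment values are assumed example numbers.

import os

os.environ["VLLM_DT_MAX_BATCH_SIZE"] = "4"
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = "2048"
BLOCK_SIZE = 64  # assumed example value

_MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
_MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])

kwargs = {"_kvcache_num_blocks_hint": 2080}
NUM_BLOCKS = kwargs.get(
    "_kvcache_num_blocks_hint", (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
)
print(NUM_BLOCKS)  # 2080 from the hint; without the hint: (4 * 2048) // 64 == 128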
@@ -345,7 +350,10 @@ def generate(
         [
             (
                 [b_seq[0]]
-                * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                * (
+                    max(2 if is_fp8 else 1, max([len(b) for b in block_table]))
+                    - len(b_seq)
+                )
             )
             + b_seq
             for b_seq in block_table
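A standalone sketch (not part of the diff) of the padding in this hunk: each sequence's block list is left-padded with its own first block id up to the widest list in the batch, with the minimum width of 2 now required only for fp8. The `block_table` here is an invented example.

is_fp8 = False
block_table = [[7], [3, 5]]
padded = [
    (
        [b_seq[0]]
        * (max(2 if is_fp8 else 1, max([len(b) for b in block_table])) - len(b_seq))
    )
    + b_seq
    for b_seq in block_table
]
print(padded)  # [[7, 7], [3, 5]]; with is_fp8 = True and block_table = [[7]] it becomes [[7, 7]]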
@@ -408,17 +416,19 @@ def generate(
         if post_iteration_hook is not None:
             _logits = logits
             _next_val = next_val
-            # since we cannot handle batch size 1 and mimic with batch size 2, we need to only pass in the first logits/next_val
-            if not is_batch:
+            # since we cannot handle batch size 1 for fp8 and mimic it with batch size 2, we only need to pass in the first logits/next_val
+            if is_fp8 and not is_batch:
                 _logits = logits[0].unsqueeze(0)
                 _next_val = _next_val[0].unsqueeze(0)
             _next_val, kwargs = post_iteration_hook(
                 i + prompt_length, _logits, _next_val, kwargs
             )
             # we need to normalize back to batch size 2
-            if not is_batch:
+            if is_fp8 and not is_batch:
                 # we need to do an in-place copy here for the same reason we do in-place copy for injecting tokens
                 next_val.copy_(torch.cat((_next_val, _next_val), dim=0))
+            else:
+                next_val = _next_val

         result = torch.cat((result, next_val), dim=-1)

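A self-contained sketch (not part of the diff) of the hook handling above for the fp8 batch-of-2 mimic: only the first row is handed to the hook, and its output is copied back into both rows in place; in the regular batched path the hook output is now taken directly. The hook itself is hypothetical.

import torch

def post_iteration_hook(step, logits, next_val, kwargs):
    return next_val + 100, kwargs  # hypothetical token rewrite

is_fp8, is_batch = True, False        # fp8 model, single prompt mimicked as batch of 2
logits = torch.randn(2, 32000)
next_val = torch.tensor([[5], [5]])   # both rows hold the same mimicked token

_logits, _next_val = logits, next_val
if is_fp8 and not is_batch:
    _logits = logits[0].unsqueeze(0)
    _next_val = _next_val[0].unsqueeze(0)
_next_val, kwargs = post_iteration_hook(0, _logits, _next_val, {})
if is_fp8 and not is_batch:
    # in-place copy so downstream references to next_val see the hook's result in both rows
    next_val.copy_(torch.cat((_next_val, _next_val), dim=0))
else:
    next_val = _next_val
print(next_val)  # tensor([[105], [105]])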
@@ -454,7 +464,12 @@ def generate(
     return result


-VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072
+# this value defaults to 2080 to be consistent with vllm for granite 3.3 8b instruct
+KVCACHE_NUM_BLOCKS_HINT = int(
+    os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
+)
+
+VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))


 class ProgramCriteria:
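Since both new limits are module-level constants read from the environment at import time, an override has to be in place before the module is imported. A small sketch (not part of the diff), with a hypothetical import path:

import os

# set overrides before importing the module that defines these constants
os.environ["AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT"] = "4096"
os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "262144"

# from aiu_fms_testing_utils.utils.paged import generate  # hypothetical import path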
@@ -468,7 +483,11 @@ def __init__(
         self.tkv_granularity = tkv_granularity

     def is_possible(self, batch_size, tkv):
-        return batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT
+        return (
+            (batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT)
+            and (batch_size <= self.max_batch)
+            and (tkv <= self.max_tkv)
+        )

     def calculate_padding(self, batch_size, tkv):
         min_batch_req = (
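A minimal reproduction (not part of the diff) of the tightened check: a (batch_size, tkv) pair must now also fit the program's own max_batch and max_tkv, not just the global batch*tkv budget. The stand-in class below only mirrors the fields this method touches.

VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072

class _Criteria:
    def __init__(self, max_batch, max_tkv):
        self.max_batch = max_batch
        self.max_tkv = max_tkv

    def is_possible(self, batch_size, tkv):
        return (
            (batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT)
            and (batch_size <= self.max_batch)
            and (tkv <= self.max_tkv)
        )

c = _Criteria(max_batch=4, max_tkv=8192)
print(c.is_possible(4, 8192))  # True: within all three limits
print(c.is_possible(8, 8192))  # False now: exceeds max_batch even though 8 * 8192 <= the global limit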
@@ -496,7 +515,12 @@ def __hash__(self):


 def get_programs_prompts(
-    program_criteria_list, multiple, max_batch_size, max_tkv, program_cycles
+    program_criteria_list,
+    multiple,
+    max_batch_size,
+    max_tkv,
+    program_cycles,
+    prioritize_large_batch_sizes=True,
 ):
     program_map = {}

@@ -515,6 +539,11 @@ def get_programs_prompts(
                 if (
                     resolved_programs[program_index] is None
                     or padding < resolved_programs[program_index][1]
+                    or (
+                        padding == resolved_programs[program_index][1]
+                        and program_criteria.batch_granularity
+                        > resolved_programs[program_index][0].batch_granularity
+                    )
                 ):
                     resolved_programs[program_index] = (
                         program_criteria,
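A stripped-down sketch (not part of the diff) of the tie-break added in this hunk: with equal padding, the candidate with the larger batch_granularity now replaces the previously resolved program. The objects and values below are hypothetical stand-ins.

from types import SimpleNamespace

resolved = None  # (program_criteria, padding)
candidates = [
    (SimpleNamespace(batch_granularity=1), 16),
    (SimpleNamespace(batch_granularity=4), 16),  # same padding, coarser batch granularity
]
for program_criteria, padding in candidates:
    if (
        resolved is None
        or padding < resolved[1]
        or (
            padding == resolved[1]
            and program_criteria.batch_granularity > resolved[0].batch_granularity
        )
    ):
        resolved = (program_criteria, padding)
print(resolved[0].batch_granularity)  # 4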
@@ -528,4 +557,8 @@ def get_programs_prompts(
             else:
                 program_map[key] = [(batch_size, prompt_len)]

+    # give higher priority to larger batches
+    for _, v in program_map.items():
+        v.sort(key=lambda t: t[0], reverse=prioritize_large_batch_sizes)
+
     return program_map
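A small sketch (not part of the diff) of the new sort at the end of get_programs_prompts: each program's list of (batch_size, prompt_len) shapes is ordered by batch size, largest first, unless prioritize_large_batch_sizes is passed as False. The map contents are invented.

prioritize_large_batch_sizes = True
program_map = {("p0",): [(1, 128), (4, 64), (2, 256)]}

for _, v in program_map.items():
    v.sort(key=lambda t: t[0], reverse=prioritize_large_batch_sizes)

print(program_map)  # {('p0',): [(4, 64), (2, 256), (1, 128)]}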