@@ -33,8 +33,8 @@ def __init__(
 
         engine_args = AsyncEngineArgs(
             model=model,
-            tensor_parallel_size=tensor_parallel_size,
-            gpu_memory_utilization=gpu_memory_utilization,
+            tensor_parallel_size=int(tensor_parallel_size),
+            gpu_memory_utilization=float(gpu_memory_utilization),
             trust_remote_code=kwargs.get("trust_remote_code", True),
             disable_log_stats=False,
         )
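
The casts guard against tensor_parallel_size and gpu_memory_utilization arriving as strings, e.g. when they are read from environment variables or a config layer before being forwarded into the constructor. A minimal sketch of the failure mode the change prevents (the env var names here are illustrative, not from this repo):

    import os

    # Values pulled from the environment are strings, not numbers.
    tensor_parallel_size = os.environ.get("TENSOR_PARALLEL_SIZE", "1")        # "1" (str)
    gpu_memory_utilization = os.environ.get("GPU_MEMORY_UTILIZATION", "0.9")  # "0.9" (str)

    # Passing the raw strings into AsyncEngineArgs trips type validation or
    # arithmetic inside the engine; the explicit casts normalize them first.
    tensor_parallel_size = int(tensor_parallel_size)        # 1
    gpu_memory_utilization = float(gpu_memory_utilization)  # 0.9
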
@@ -82,15 +82,15 @@ async def generate_answer(
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
-    ) -> List[Token]:
+    ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
-
         request_id = f"graphgen_topk_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
+            prompt_logprobs=1,
         )
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
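
With temperature=0, max_tokens=1, and logprobs=self.topk, vLLM decodes a single greedy token and attaches the top-k candidate log-probabilities to it; the new prompt_logprobs=1 additionally asks the engine to score each prompt token. A minimal sketch of draining the async generator and reading both fields, assuming self.engine is a vllm.AsyncLLMEngine (variable names are illustrative):

    # Drain the async generator; the last RequestOutput is the finished one.
    final_output = None
    async for request_output in result_generator:
        final_output = request_output

    # outputs[0].logprobs is a list with one dict per generated token; with
    # max_tokens=1 there is exactly one, mapping token id -> Logprob objects
    # that carry .logprob and .decoded_token.
    top_logprobs = final_output.outputs[0].logprobs[0]

    # prompt_logprobs=1 fills final_output.prompt_logprobs with one entry per
    # prompt position (the first entry is None, since the first prompt token
    # has no preceding context to be scored against).
    prompt_scores = final_output.prompt_logprobs
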
@@ -108,14 +108,22 @@ async def generate_topk_per_token(
 
         top_logprobs = final_output.outputs[0].logprobs[0]
 
-        tokens = []
+        candidate_tokens = []
         for _, logprob_obj in top_logprobs.items():
-            tok_str = logprob_obj.decoded_token
+            tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
             prob = float(math.exp(logprob_obj.logprob))
-            tokens.append(Token(tok_str, prob))
-
-        tokens.sort(key=lambda x: -x.prob)
-        return tokens
+            candidate_tokens.append(Token(tok_str, prob))
+
+        candidate_tokens.sort(key=lambda x: -x.prob)
+
+        if candidate_tokens:
+            main_token = Token(
+                text=candidate_tokens[0].text,
+                prob=candidate_tokens[0].prob,
+                top_candidates=candidate_tokens
+            )
+            return [main_token]
+        return []
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
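
The return shape changes from a flat top-k list to a single Token whose top_candidates field carries the full, probability-sorted candidate list, so callers read the argmax token from element zero. A hedged usage sketch, assuming a Token type with text, prob, and top_candidates fields as the constructor calls above suggest (llm is an illustrative instance of this wrapper):

    tokens = await llm.generate_topk_per_token("The capital of France is")
    if tokens:
        main = tokens[0]
        print(main.text, main.prob)        # greedy token and its probability
        for cand in main.top_candidates:   # full top-k distribution, sorted
            print(f"  {cand.text!r}: {cand.prob:.4f}")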