
Commit 02adac3

Merge branch 'main' of https://github.com/open-sciencelab/GraphGen into main
2 parents d60238e + 9fcad68 commit 02adac3

2 files changed: +19 -11 lines changed


graphgen/common/init_llm.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def create_llm(
         ray.get_actor(actor_name)
     except ValueError:
         print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
-        num_gpus = int(config.pop("num_gpus", 0))
+        num_gpus = float(config.pop("num_gpus", 0))
         actor = (
             ray.remote(LLMServiceActor)
             .options(
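
The int-to-float widening matters because Ray accepts fractional GPU shares when scheduling actors, so a configured value such as 0.5 was previously truncated to 0. A minimal sketch of the pattern, with a hypothetical actor standing in for LLMServiceActor (whose constructor is not part of this diff):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class EchoActor:
    # Hypothetical stand-in for LLMServiceActor, only here to illustrate scheduling.
    def ping(self) -> str:
        return "pong"

# Fractional GPU shares are valid in Ray; int() would truncate 0.5 down to 0.
num_gpus = float("0.5")
# The actor is only scheduled once the cluster actually has GPU capacity to share.
actor = EchoActor.options(name="demo_llm_actor", num_gpus=num_gpus).remote()
print(ray.get(actor.ping.remote()))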

graphgen/models/llm/local/vllm_wrapper.py

Lines changed: 18 additions & 10 deletions
@@ -33,8 +33,8 @@ def __init__(
 
         engine_args = AsyncEngineArgs(
             model=model,
-            tensor_parallel_size=tensor_parallel_size,
-            gpu_memory_utilization=gpu_memory_utilization,
+            tensor_parallel_size=int(tensor_parallel_size),
+            gpu_memory_utilization=float(gpu_memory_utilization),
             trust_remote_code=kwargs.get("trust_remote_code", True),
             disable_log_stats=False,
         )
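
The explicit int()/float() casts guard against tensor_parallel_size and gpu_memory_utilization arriving as strings, e.g. from a YAML config or environment variables. A sketch of the same construction in isolation, assuming a recent vLLM; the model id is a placeholder, not taken from this diff:

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Config values often arrive as strings; cast them before handing them to vLLM.
raw_cfg = {"tensor_parallel_size": "1", "gpu_memory_utilization": "0.9"}

engine_args = AsyncEngineArgs(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    tensor_parallel_size=int(raw_cfg["tensor_parallel_size"]),
    gpu_memory_utilization=float(raw_cfg["gpu_memory_utilization"]),
    trust_remote_code=True,
    disable_log_stats=False,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
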
@@ -82,15 +82,15 @@ async def generate_answer(
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
-    ) -> List[Token]:
+    ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
-
         request_id = f"graphgen_topk_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
+            prompt_logprobs=1,
         )
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
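
The added prompt_logprobs=1 asks vLLM to score the prompt tokens as well, alongside the top-k alternatives for the single generated token. A standalone sketch of the sampling parameters, assuming a recent vLLM and a placeholder top-k value in place of self.topk:

from vllm import SamplingParams

topk = 5  # placeholder; the wrapper uses self.topk
sp = SamplingParams(
    temperature=0,      # greedy decoding
    max_tokens=1,       # only the next-token distribution is needed
    logprobs=topk,      # return the top-k candidates for the generated token
    prompt_logprobs=1,  # also return log-probabilities for each prompt token
)
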
@@ -108,14 +108,22 @@ async def generate_topk_per_token(
 
         top_logprobs = final_output.outputs[0].logprobs[0]
 
-        tokens = []
+        candidate_tokens = []
         for _, logprob_obj in top_logprobs.items():
-            tok_str = logprob_obj.decoded_token
+            tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
             prob = float(math.exp(logprob_obj.logprob))
-            tokens.append(Token(tok_str, prob))
-
-        tokens.sort(key=lambda x: -x.prob)
-        return tokens
+            candidate_tokens.append(Token(tok_str, prob))
+
+        candidate_tokens.sort(key=lambda x: -x.prob)
+
+        if candidate_tokens:
+            main_token = Token(
+                text=candidate_tokens[0].text,
+                prob=candidate_tokens[0].prob,
+                top_candidates=candidate_tokens
+            )
+            return [main_token]
+        return []
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
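
With this change, generate_topk_per_token returns at most one Token whose top_candidates field carries the full ranked candidate list, rather than returning the ranked list directly. A consumption sketch, assuming Token is roughly the dataclass implied by the calls above (the real class lives elsewhere in GraphGen and may differ):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Token:
    # Assumed shape, inferred from Token(tok_str, prob) and the keyword call in the diff.
    text: str
    prob: float
    top_candidates: List["Token"] = field(default_factory=list)

def summarize(tokens: List[Token]) -> None:
    # New contract: an empty list, or exactly one Token carrying its alternatives.
    if not tokens:
        print("no logprobs returned")
        return
    main = tokens[0]
    print(f"predicted {main.text!r} (p={main.prob:.3f})")
    for cand in main.top_candidates:
        print(f"  candidate {cand.text!r} (p={cand.prob:.3f})")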
