6 changes: 6 additions & 0 deletions python/infinilm/llm/cache_manager.py
@@ -261,6 +261,12 @@ def try_free_blocks(self, num_required: int) -> bool:
def get_num_free_blocks(self) -> int:
return len(self.free_block_ids)

def get_total_usable_blocks(self) -> int:
freeable_used_blocks = sum(
1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
)
return len(self.free_block_ids) + freeable_used_blocks

def __repr__(self):
return (
f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
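Note: the new get_total_usable_blocks() counts blocks that are free right now plus used blocks whose ref_count has dropped to 0 (i.e. reclaimable on demand). A minimal sketch of that accounting, using illustrative data rather than the project's BlockManager:

```python
from dataclasses import dataclass

@dataclass
class Block:
    ref_count: int = 0

# Illustrative state: ids 4 and 5 are free; of the used blocks,
# 0 and 2 have no remaining references and could be reclaimed.
blocks = {0: Block(0), 1: Block(2), 2: Block(0), 3: Block(1)}
free_block_ids = {4, 5}
used_block_ids = {0, 1, 2, 3}

freeable = sum(1 for bid in used_block_ids if blocks[bid].ref_count == 0)
print(len(free_block_ids) + freeable)  # 4 usable blocks in total
```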
91 changes: 86 additions & 5 deletions python/infinilm/llm/llm.py
@@ -228,12 +228,63 @@ def _update_requests(
req.generated_token_ids.append(token_id)
if req.is_prefill:
req.is_prefill = False
# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
token_text = self.detokenize([token_id])
req.generated_text += token_text
else:
# Streaming path: compute delta from a full decode so we can hold back
# trailing '\ufffd' (likely an incomplete UTF-8 sequence).
decoded_text = self.detokenize(req.generated_token_ids)

finished_now = False
# Update generated_text to the latest decode (used for stop-string checks and debugging)
req.generated_text = decoded_text

if self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
finished_now = True

# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[:-len(stop_str)]
req.generated_text = decoded_text
break

holds_back_incomplete_utf8 = (
bool(decoded_text) and decoded_text.endswith("\ufffd")
)

token_text = self.tokenizer.decode(token_id)
req.generated_text += token_text

if self._check_request_finished(req, token_id):
# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
token_text = ""
else:
last_len = getattr(req, "_stream_last_yielded_length", 0)
token_text = decoded_text[last_len:]
if token_text:
req._stream_last_yielded_length = len(decoded_text)
Contributor comment: Handle the possibly incomplete character left at the tail (when the request finished due to max_tokens).


# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[:-len(stop_str)]
break

# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
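Note: the streaming branch above re-decodes the full token sequence each step, holds back a trailing U+FFFD (an incomplete UTF-8 sequence is likely still in flight), and emits only the delta since the last yield. A standalone sketch of that chunking rule, not the engine's exact code:

```python
def next_chunk(decoded_text: str, last_yielded_len: int, is_final: bool) -> tuple[str, int]:
    """Return (text to emit, new yielded length) for one decode step."""
    # Hold back if the decode ends in a replacement character and more tokens may follow.
    if decoded_text.endswith("\ufffd") and not is_final:
        return "", last_yielded_len
    return decoded_text[last_yielded_len:], len(decoded_text)

# A multi-byte character split across two tokens:
chunk, n = next_chunk("Hi \ufffd", 0, is_final=False)  # -> ("", 0): held back
chunk, n = next_chunk("Hi 你", n, is_final=False)       # -> ("Hi 你", 4)
```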
@@ -283,12 +334,15 @@ def apply_chat_template(
self,
messages: List[dict],
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> str:
"""Apply chat template to messages."""
chat_template_kwargs = chat_template_kwargs or {}
return self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=add_generation_prompt,
tokenize=False,
**chat_template_kwargs,
)
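
With chat_template_kwargs now passed through, callers can forward template-specific switches. A hypothetical call (engine stands for an LLM instance; the enable_thinking flag is only meaningful for templates that define it, e.g. Qwen3-style, and is not part of this PR):

```python
prompt = engine.apply_chat_template(
    messages=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
    chat_template_kwargs={"enable_thinking": False},
)
```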


@@ -486,6 +540,10 @@ def __init__(

self._running = False
self._step_thread: Optional[threading.Thread] = None
self._healthy = True

def is_healthy(self) -> bool:
return bool(self._healthy)

def start(self):
"""Start the background inference loop."""
@@ -520,6 +578,7 @@ def _step_loop(self):
time.sleep(0.01)
except Exception as e:
logger.error(f"Error in step loop: {e}", exc_info=True)
self._healthy = False
self._running = False
break
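
Note: once _healthy flips to False in the step loop, is_healthy() lets the serving layer fail readiness probes. One possible consumer, purely illustrative and not part of this PR (the FastAPI wiring and endpoint name are assumptions; async_engine stands for the engine wrapper that owns the step loop):

```python
from fastapi import FastAPI, Response

def make_health_app(async_engine) -> FastAPI:
    app = FastAPI()

    @app.get("/health")
    def health() -> Response:
        # Return 503 once the background step loop has died.
        return Response(status_code=200 if async_engine.is_healthy() else 503)

    return app
```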

@@ -581,6 +640,8 @@ def add_chat_request(
request_id: Optional[str] = None,
request_data: Optional[dict] = None,
http_request: Optional[any] = None,
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> InferenceRequest:
"""Add a chat request to the engine.

@@ -594,7 +655,11 @@
Returns:
The created InferenceRequest object.
"""
prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
prompt = self.engine.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=chat_template_kwargs,
)
return self.add_request(
prompt=prompt,
sampling_params=sampling_params,
@@ -607,6 +672,7 @@ async def stream_request(
self,
request: InferenceRequest,
timeout: float = 100.0,
request_timeout: Optional[float] = None,
) -> AsyncIterator[TokenOutput]:
"""Stream tokens from a request.

@@ -619,6 +685,7 @@
"""
import asyncio

start = time.time()
while True:
if request.is_finished() and request.output_queue.async_q.empty():
break
@@ -635,6 +702,20 @@
if token_output.finished:
break
except asyncio.TimeoutError:
# Enforce request-level timeout even if no tokens are produced.
if request_timeout is not None:
now = time.time()
if now - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break
if request.is_finished():
break
continue
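Note: request_timeout bounds the whole stream's wall-clock time; the check fires when a per-token wait times out and total elapsed time exceeds the limit, after which a final TokenOutput with finish_reason == FinishReason.TIMEOUT is yielded. An illustrative consumption loop (timeout values are placeholders, and the FinishReason import path is assumed):

```python
from infinilm.llm.request import FinishReason  # import path assumed

async def consume(async_engine, request) -> str:
    pieces = []
    async for out in async_engine.stream_request(request, timeout=1.0, request_timeout=30.0):
        if out.finished and out.finish_reason == FinishReason.TIMEOUT:
            break  # the request exceeded its overall deadline while idle
        pieces.append(out.token_text)
    return "".join(pieces)
```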
4 changes: 4 additions & 0 deletions python/infinilm/llm/request.py
@@ -144,6 +144,10 @@ def __init__(
# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None

# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode.
self._stream_last_yielded_length: int = 0

@property
def output_queue(self) -> janus.Queue:
"""Lazy initialization of output queue."""
2 changes: 1 addition & 1 deletion python/infinilm/llm/sampling_params.py
@@ -15,7 +15,7 @@ class SamplingParams:
top_k: int = 1
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
stop_token_ids: Optional[List[int]] = None  # Placeholder for future use; not currently handled

def __post_init__(self):
if self.stop is None:
42 changes: 40 additions & 2 deletions python/infinilm/llm/scheduler.py
@@ -155,12 +155,21 @@ def schedule(self) -> Optional[SchedulerOutput]:
except queue.Empty:
break

if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req)
break

# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
continue

req_tokens = req.get_input_tokens()
num_required_blocks = req.get_num_blocks_required(self.block_size)

if not self.cache_manager.can_allocate(num_required_blocks):
if not self.cache_manager.try_free_blocks(num_required_blocks):
raise RuntimeError("No available cache blocks")
raise RuntimeError("No available cache blocks for new request")

# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
@@ -185,6 +194,10 @@
req = self.running_queue.sync_q.get_nowait()
except queue.Empty:
break
# Skip requests that were already finished (e.g., timed out/canceled while running)
if req.is_finished():
self.complete_requests([req])
continue

# Decode phase: allocate slot for newly generated token
try:
@@ -197,7 +210,7 @@
scheduled_requests.append(req)

except RuntimeError as e:
raise RuntimeError("No available cache blocks") from e
raise RuntimeError("No available cache blocks for new token") from e

# Return decode batch if any running requests were scheduled
if scheduled_requests:
@@ -237,6 +250,31 @@ def complete_requests(self, requests: List[InferenceRequest]):
# Still running, put back in running queue
self.running_queue.sync_q.put(req)

def can_accept_request(self, request: InferenceRequest) -> bool:
total_required_blocks = 0

# Calculate blocks needed for running requests
running_queue_size = self.running_queue.sync_q.qsize()
for _ in range(running_queue_size):
req = self.running_queue.sync_q.get()
remaining_tokens = (
req.sampling_params.max_tokens - req.get_num_generated_tokens()
)
num_blocks_needed = (
remaining_tokens + self.block_size - 1
) // self.block_size
total_required_blocks += num_blocks_needed
self.running_queue.sync_q.put(req)

# Calculate blocks needed for the new request
total_length = request.get_prompt_length()
total_length += request.sampling_params.max_tokens
num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
total_required_blocks += num_blocks_needed

# Compare with total usable blocks in cache manager
return total_required_blocks <= self.cache_manager.get_total_usable_blocks()

def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
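Note: can_accept_request() sums worst-case block needs, with ceiling division by block_size, for every running request plus the incoming one, and admits the new request only if that total fits in get_total_usable_blocks(); it assumes sampling_params.max_tokens is set. Worked numbers (illustrative):

```python
block_size = 16

# One running request: max_tokens=128, 40 tokens generated so far.
remaining = 128 - 40                                         # 88 tokens still possible
running_blocks = (remaining + block_size - 1) // block_size  # ceil(88/16)  -> 6

# Incoming request: 50 prompt tokens + max_tokens=128.
new_blocks = (50 + 128 + block_size - 1) // block_size       # ceil(178/16) -> 12

required = running_blocks + new_blocks                       # 18
# Admit only if required <= cache_manager.get_total_usable_blocks()
```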