57 changes: 57 additions & 0 deletions tensorrt_llm/llmapi/async_llm.py
@@ -0,0 +1,57 @@
from typing import Any, Optional

from .llm import LLM


# We extend LLM with an AsyncLLM subclass rather than modifying the LLM class
# directly, since AsyncLLM operations depend on Ray, while the LLM class should
# remain generic and independent of the orchestrator.
# NOTE: This class is for internal use only, not for external use. It is
# subject to frequent change and may not be stable.
class AsyncLLM(LLM):
    """AsyncLLM is a subclass of LLM that supports the asynchronous setup,
    release, and resume operations needed for RL or agentic scenarios.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: This should stay lightweight; heavy initialization should be
        # done in the async_init_phase.
        super().__init__(*args, **kwargs)

    # The LLM.generate_async method is inherited as-is.

    # I doubt we need to support both sync and async modes for each method;
    # async-only may be enough, since RL developers are more likely to use
    # the async mode.

    async def setup_async(self):
        """Set up the LLM asynchronously."""
        pass

    async def release_async(self):
        """Asynchronously release the GPU memory used by the LLM."""
        pass

    async def resume_async(self):
        """Resume the LLM asynchronously."""
        pass

    def collective_rpc_call(
        self,
        method: str,
        args: tuple[Any, ...] = (),
        kwargs: Optional[dict] = None,
        non_block: bool = False,
        unique_reply_rank: Optional[int] = None,
    ) -> list[Any]:
        """Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

        Args:
            method (str): The name of the worker method to execute.
            args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
            kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
            non_block (bool): If True, return immediately without waiting for all workers to complete the RPC call. Defaults to False.
            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None.

        Returns:
            list[Any]: A list of results from each worker.
        """
        pass
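
Since the new methods above are still stubs in this draft, the intended call pattern is easiest to see in a sketch. The following usage example is not part of the PR: the model path, prompt, and the "report_device_id" worker method are made up for illustration, setup_async/release_async/resume_async are assumed to do real work once implemented, and awaiting generate_async follows the existing LLM API.

import asyncio

from tensorrt_llm.llmapi.async_llm import AsyncLLM


async def main():
    # Construction stays lightweight; heavy initialization is deferred.
    llm = AsyncLLM(model="/path/to/model")  # hypothetical model path
    await llm.setup_async()  # heavy async initialization happens here

    # generate_async is inherited from LLM.
    output = await llm.generate_async("Hello, world")
    print(output)

    # Fan an RPC out to every Ray worker (RayExecutor only); "report_device_id"
    # is a made-up worker method name used only for illustration.
    device_ids = llm.collective_rpc_call("report_device_id")
    print(device_ids)

    # Free GPU memory between RL rollouts, then resume before the next one.
    await llm.release_async()
    await llm.resume_async()


if __name__ == "__main__":
    asyncio.run(main())
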
30 changes: 2 additions & 28 deletions tensorrt_llm/llmapi/llm.py
@@ -1009,34 +1009,6 @@ def __init__(self,
backend=backend,
**kwargs)

@set_api_status("prototype")
def _collective_rpc(self,
method: str,
args: tuple[Any, ...] = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
"""
Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

Args:
method (str): The name of the worker method to execute.
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
non_block (bool): Whether to block until all workers have completed the RPC call. Defaults to False.
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None.

Returns:
list[Any]: A list of results from each worker.
"""
if hasattr(self._executor, 'collective_rpc'):
return self._executor.collective_rpc(method, args, kwargs,
non_block, unique_reply_rank)
else:
raise ValueError(
f"Executor type {type(self._executor)} does not support collective RPC."
)

def _build_model(self):
super()._build_model()
assert self._engine_dir is None
@@ -1125,7 +1097,9 @@ def __init__(self,
Parameters:
""" + TORCH_LLM_DOCSTRING


class AsyncLLM(LLM):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

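
For reference, the _collective_rpc body removed from llm.py above suggests how AsyncLLM.collective_rpc_call might eventually be filled in: delegate to the executor's collective_rpc when the executor provides it. A minimal sketch, not part of the PR, assuming the _executor attribute and the collective_rpc signature stay as shown in the removed code:

from typing import Any, Optional

from tensorrt_llm.llmapi import LLM


class AsyncLLM(LLM):
    """Sketch only: one possible body for collective_rpc_call."""

    def collective_rpc_call(self,
                            method: str,
                            args: tuple[Any, ...] = (),
                            kwargs: Optional[dict] = None,
                            non_block: bool = False,
                            unique_reply_rank: Optional[int] = None) -> list[Any]:
        # Delegate to the executor, mirroring the LLM._collective_rpc body
        # removed in this PR; only executors that implement collective_rpc
        # (currently RayExecutor) are supported.
        if hasattr(self._executor, 'collective_rpc'):
            return self._executor.collective_rpc(method, args, kwargs,
                                                 non_block, unique_reply_rank)
        raise ValueError(
            f"Executor type {type(self._executor)} does not support collective RPC."
        )
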