Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions bench/small_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import os

from transformers import AutoTokenizer
from ssd import LLM, SamplingParams

if __name__ == '__main__':

    # Hard-coded HF snapshot locations for this benchmark host.
    llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6'
    llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b'
    eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd'
    # eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B'

    # Fail fast with a clear error if a snapshot is missing. A bare `assert`
    # would be silently stripped under `python -O`, so raise explicitly.
    for _path in (llama_1b_path, llama_70b_path, eagle_path):
        if not os.path.isdir(_path):
            raise FileNotFoundError(f'model snapshot directory not found: {_path}')

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=llama_1b_path)
    parser.add_argument("--draft", type=str, default=llama_1b_path)
    parser.add_argument("--eagle", action="store_true")
    parser.add_argument("--k", type=int, default=7)
    parser.add_argument("--jit-speculate", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=2)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--chat-template", action="store_true")
    parser.add_argument("--communicate-logits", action="store_true")
    parser.add_argument("--communicate-cache-hits", action="store_true")
    parser.add_argument("--mary", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # --eagle is a preset: it overrides model/draft/gpu/template flags so the
    # 70B target + EAGLE3 draft configuration can be launched with one flag.
    if args.eagle:
        args.draft = eagle_path
        args.model = llama_70b_path
        args.num_gpus = 5
        args.jit_speculate = True
        args.chat_template = True

    llm = LLM(
        model=args.model,
        draft=args.draft,
        use_eagle=args.eagle,
        speculate_k=args.k,
        speculate=True,
        draft_async=True,
        num_gpus=args.num_gpus,
        jit_speculate=args.jit_speculate,
        verbose=args.verbose,
        communicate_logits=args.communicate_logits,
        communicate_cache_hits=args.communicate_cache_hits,
    )
    # Greedy decoding (temperature 0) so runs are deterministic and comparable.
    sampling_params = [SamplingParams(temperature=0.0, max_new_tokens=64, ignore_eos=args.ignore_eos)]

    # --mary selects a longer, repetition-heavy prompt; the default is a short
    # factual question.
    if args.mary:
        text = "Can you please tell me the lyrics to Mary had a little lamb, and can you repeat it 10 times?"
    else:
        text = "What is the capital city of France?"

    if args.chat_template:
        # Instruct-tuned targets expect the chat template, so tokenize here and
        # hand the engine token ids instead of raw text.
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        tokens = tokenizer.apply_chat_template(
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
            add_generation_prompt=True,
        )
        token_str = tokenizer.decode(tokens)
        print(f"Generating response to prompt: '{token_str}'")
        print("=============================================================")  # was an f-string with no placeholders
        outputs, _ = llm.generate([tokens], sampling_params)

    else:
        outputs, _ = llm.generate([text], sampling_params)

    print(outputs[0]["text"])
24 changes: 9 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@ readme = "README.md"
description = "Async tree-based speculative decoding research engine"
requires-python = ">=3.11,<3.13"
dependencies = [
"torch==2.8.0",
"triton==3.4.0",
"torch==2.9.1",
"triton",
"transformers==4.57.1",
"xxhash==3.5.0",
"numpy==2.3.3",
"safetensors==0.6.2",
"tqdm==4.67.1",
"flashinfer-python==0.5.2",
"sgl-kernel==0.3.17.post1",
"nvidia-cutlass-dsl==4.2.1",
"xxhash",
"numpy",
"safetensors",
"tqdm",
"flashinfer-python==0.6.6",
"sgl-kernel==0.3.21",
"nvidia-cutlass-dsl>=4.3.4",
"wandb==0.22.0",
"hf_transfer",
"tiktoken",
]

[project.optional-dependencies]
scripts = [
"datasets",
"huggingface_hub",
]

[project.urls]
Homepage="https://github.com/tanishqkumar/ssd"

Expand Down
4 changes: 3 additions & 1 deletion ssd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@
prepare_decode_tensors_from_seqs,
prepare_block_tables_from_seqs,
prepare_prefill_tensors_from_seqs,
prepare_prefill_payload,
PrefillRequest,
SpeculationRequest,
SpeculationResponse,
)
50 changes: 36 additions & 14 deletions ssd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import torch
from ssd.paths import DEFAULT_TARGET, DEFAULT_DRAFT


@dataclass
class Config:
model: str = DEFAULT_TARGET
max_num_batched_tokens: int = 16384
max_num_seqs: int = 1
max_model_len: int = 4096
max_num_seqs: int = 1
max_model_len: int = 4096
gpu_memory_utilization: float = 0.7
num_gpus: int = 1
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 256
kvcache_block_size: int = 1
num_kvcache_blocks: int = -1
device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Expand All @@ -25,13 +26,17 @@ class Config:
draft: str = DEFAULT_DRAFT
speculate_k: int = 1
draft_async: bool = False

# async spec only
async_fan_out: int = 3
fan_out_list: list[int] | None = None
fan_out_list_miss: list[int] | None = None
sampler_x: float | None = None
jit_speculate: bool = False
jit_speculate: bool = False
async_nccl_port: int | None = None
async_nccl_host: str = "127.0.0.1"
communicate_logits: bool = False
communicate_cache_hits: bool = False

# eagle3
use_eagle: bool = False
Expand All @@ -49,26 +54,35 @@ def max_blocks(self):
return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size

def __post_init__(self):
model = self.model
model = self.model
assert os.path.isdir(model)

assert 1 <= self.num_gpus <= 8 # this codebase only works on one node
self.hf_config = AutoConfig.from_pretrained(model)
self.max_model_len = min(
self.max_model_len, self.hf_config.max_position_embeddings)
if self.speculate:

if not self.speculate:
if self.max_model_len:
self.max_model_len = min(
self.max_model_len, self.hf_config.max_position_embeddings)
else:
self.max_model_len = self.hf_config.max_position_embeddings
else:
draft = self.draft
self.draft_hf_config = AutoConfig.from_pretrained(draft)
self.max_model_len = min(
self.max_model_len, self.draft_hf_config.max_position_embeddings)
if self.max_model_len:
self.max_model_len = min(
self.max_model_len, self.draft_hf_config.max_position_embeddings)
else:
self.max_model_len = self.draft_hf_config.max_position_embeddings

if self.draft_async:
if self.fan_out_list is None:
self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1)
self.MQ_LEN = sum(self.fan_out_list)
if self.fan_out_list_miss is None:
self.fan_out_list_miss = self.fan_out_list
assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list"

if self.use_eagle:
if self.eagle_layers is None:
L = self.hf_config.num_hidden_layers
Expand All @@ -90,5 +104,13 @@ def __post_init__(self):
if target_max_pos != draft_max_pos:
print(f'[Config] Overriding eagle draft max_position_embeddings: {draft_max_pos} -> {target_max_pos}', flush=True)
self.draft_hf_config.max_position_embeddings = target_max_pos

assert self.max_num_batched_tokens >= self.max_model_len

if self.sampler_x is not None and not self.communicate_cache_hits:
self.communicate_cache_hits = True
print(f'[Config] Setting communicate_cache_hits to True because sampler_x is not None', flush=True)

# assert self.max_num_batched_tokens >= self.max_model_len
if self.max_num_batched_tokens < self.max_model_len:
print(f'[Config] Warning: max_num_batched_tokens ({self.max_num_batched_tokens}) is less than max_model_len ({self.max_model_len})', flush=True)
print(f'[Config] Setting max_num_batched_tokens to max_model_len', flush=True)
self.max_num_batched_tokens = self.max_model_len
5 changes: 5 additions & 0 deletions ssd/engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def _deallocate_n_blocks(self, block_ids: list[int]): # we need to separate wher

def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0

if self.blocks[block_id].hash != -1: # if block was finalized, remove from hash_to_block_id checkme
if self.hash_to_block_id.get(self.blocks[block_id].hash) == block_id:
del self.hash_to_block_id[self.blocks[block_id].hash]

self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)

Expand Down
Loading