Commit b6d2545

Consolidated Llama Guard 4 Text testing scripts and raised NotImplemented error for multimodal inputs
1 parent 8d4c35a

5 files changed (+31, -36 lines)

examples/offline_safety_model_inference.py (5 additions & 4 deletions)
@@ -2,9 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 """
-Example script for running offline safety classification inference on Llama Guard 4.
+Example script for running offline safety classification inference on safety models.

-applies the Llama Guard 4 chat template to 35 prompts from the ailuminate dataset,
+Currently supported models:
+- Llama Guard 4 (meta-llama/Llama-Guard-4-12B)
+
+applies the safety model's chat template to 35 prompts from the ailuminate dataset,
 and runs inference using the JAX backend. It calculates the final accuracy based on
 the model's 'safe'/'unsafe' and S-code classification.

@@ -59,7 +62,6 @@ def create_parser():

     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
-    # For Llama Guard, we want deterministic output.
     sampling_group.add_argument("--max-tokens", type=int, default=128)
     sampling_group.add_argument("--temperature", type=float, default=0.0)
     sampling_group.add_argument("--top-p", type=float, default=1.0)
@@ -68,7 +70,6 @@ def create_parser():


 def main(args: dict):
-    # Pop arguments not used by LLM
     max_tokens = args.pop("max_tokens")
     temperature = args.pop("temperature")
     top_p = args.pop("top_p")
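
For context on the 'safe'/'unsafe' and S-code accuracy check the docstring above describes, a minimal sketch of how such scoring might work is shown below. This is not code from the commit; the assumed output format (verdict on the first line, S-codes such as "S1,S10" on a following line) and the helper names parse_guard_output and accuracy are illustrative assumptions.

import re

def parse_guard_output(text: str) -> tuple[bool, frozenset[str]]:
    # Assumed format: first line is "safe" or "unsafe"; S-codes may follow on later lines.
    lines = text.strip().splitlines()
    is_unsafe = bool(lines) and lines[0].strip().lower() == "unsafe"
    codes = frozenset(re.findall(r"S\d+", text)) if is_unsafe else frozenset()
    return is_unsafe, codes

def accuracy(outputs: list[str], labels: list[tuple[bool, frozenset[str]]]) -> float:
    # A prompt counts as correct only if both the verdict and the S-code set match.
    correct = sum(
        parse_guard_output(out) == expected
        for out, expected in zip(outputs, labels)
    )
    return correct / max(len(labels), 1)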

scripts/vllm/integration/test_safety_model_accuracy.py (1 addition & 2 deletions)
@@ -15,7 +15,7 @@
 import pandas as pd
 from tpu_inference.core import disagg_utils

-# --- HAZARD MAPPING (REQUIRED FOR LABEL CONVERSION) ---
+# --- HAZARD MAPPING ---
 # This mapping is based on the MLCommons taxonomy and is needed for the data loader.
 # The ailuminate test set is a standard dataset for safety models
 HAZARD_MAPPING = {
@@ -109,7 +109,6 @@ def safety_test_data_loader(request):
     dataset_path = request.config.getoption("--dataset-path")

     if not dataset_path:
-        # This is expected in CI if the benchmark task definition isn't passing it yet.
         pytest.skip("The --dataset-path argument is required for data loading.")

     # Load the full dataset (Pandas handles GCS URI)
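
The HAZARD_MAPPING contents are not shown in this hunk. As a rough illustration of how a mapping from MLCommons/ailuminate hazard abbreviations to Llama Guard S-codes could feed a data loader, consider the sketch below; the specific abbreviations, S-code pairings, and column names are assumptions, not the contents of this file.

import pandas as pd

# Illustrative entries only; the real mapping lives in test_safety_model_accuracy.py.
EXAMPLE_HAZARD_MAPPING = {
    "vcr": "S1",   # violent crimes (assumed pairing)
    "ncr": "S2",   # non-violent crimes (assumed pairing)
    "hte": "S10",  # hate (assumed pairing)
}

def load_expected_labels(dataset_path: str) -> list[str]:
    # pandas can read a GCS URI (gs://...) directly when gcsfs is installed.
    df = pd.read_csv(dataset_path)
    # "hazard" is an assumed column name for the dataset's hazard abbreviation.
    return [EXAMPLE_HAZARD_MAPPING.get(h, "unknown") for h in df["hazard"]]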

tests/e2e/benchmarking/safety_model_benchmark.sh (5 additions & 19 deletions)
@@ -22,7 +22,6 @@
 set -e

 # --- Configuration & Defaults ---
-# Variables now rely on being set in the environment (e.g., via export or CI YAML)
 MODEL_NAME="${TEST_MODEL}"
 TP_SIZE="${TENSOR_PARALLEL_SIZE}"

@@ -96,7 +95,6 @@ else
 fi

 # Convert to JSONL to be compatible with vllm bench serve command
-# TODO: ensure this conversion works
 if [ ! -f "$LOCAL_JSONL_FILE" ] || [ "$TEST_MODE" == "performance" ]; then
     echo "Converting CSV to JSONL for performance run..."

@@ -131,21 +129,14 @@ fi
 run_accuracy_check() {
     echo -e "\n--- Running Accuracy Check (Mode: ACCURACY) ---"

-    # 1. Define the correct execution directory for conftest.py discovery
-    #CONFTEST_DIR="/workspace/tpu-inference/scripts/vllm/integration"
-    CONFTEST_DIR="/mnt/disks/jiries-disk_data/tpu-inference/scripts/vllm/integration"
+    CONFTEST_DIR="/workspace/tpu-inference/scripts/vllm/integration"

-    # 2. Calculate the relative path from $CONFTEST_DIR to the test file.
-    # We must go up three levels and then down into the test folder.
     RELATIVE_TEST_FILE="test_safety_model_accuracy.py"

-    # 3. Directory Change and Pytest Execution (in a subshell)
     (
-        # Change to the directory containing conftest.py
         cd "$CONFTEST_DIR" || { echo "Error: Failed to find conftest directory: $CONFTEST_DIR"; exit 1; }
         echo "Running pytest from: $(pwd)"

-        # Execute Pytest, running the test file using the relative path
         python -m pytest -s -rP "$RELATIVE_TEST_FILE::test_safety_model_accuracy_check" \
             --tensor-parallel-size="$TP_SIZE" \
             --model-name="$MODEL_NAME" \
@@ -160,7 +151,6 @@ run_accuracy_check() {
 run_performance_benchmark() {
     echo -e "\n--- Running Performance Benchmark (Mode: PERFORMANCE) ---"

-    # 1. Benchmark Execution (against the running server)
     vllm bench serve \
         --model "$MODEL_NAME" \
         --endpoint "/v1/completions" \
@@ -171,7 +161,6 @@
         --custom-output-len "$OUTPUT_LEN_OVERRIDE" \
         2>&1 | tee "$BENCHMARK_LOG_FILE"

-    # 2. Check throughput metric from the log file
     ACTUAL_THROUGHPUT=$(awk '/Output token throughput \(tok\/s\):/ {print $NF}' "$BENCHMARK_LOG_FILE")

     if [ -z "$ACTUAL_THROUGHPUT" ]; then
@@ -181,7 +170,6 @@

     echo "Actual Output Token Throughput: $ACTUAL_THROUGHPUT tok/s"

-    # 3. Perform float comparison
     if awk -v actual="$ACTUAL_THROUGHPUT" -v target="$TARGET_THROUGHPUT" 'BEGIN { exit !(actual >= target) }'; then
         echo "PERFORMANCE CHECK PASSED: $ACTUAL_THROUGHPUT >= $TARGET_THROUGHPUT"
         return 0
@@ -196,29 +184,27 @@ run_performance_benchmark() {
 # Set initial trap to ensure cleanup happens even on immediate exit
 trap 'cleanUp "$MODEL_NAME"' EXIT

-# --- 1. RUN TEST MODE (Offline Accuracy) ---
+# --- 1. RUN TEST MODE ---
 if [ "$TEST_MODE" == "accuracy" ]; then
     run_accuracy_check
     EXIT_CODE=$?
-    # Exit immediately after offline test, as server setup is unnecessary
+
     exit $EXIT_CODE
 fi

 # --- 2. START SERVER (Required ONLY for Performance Mode) ---
 if [ "$TEST_MODE" == "performance" ]; then
     echo "Spinning up the vLLM server for $MODEL_NAME (TP=$TP_SIZE)..."

-    # Server startup (NOTE: No SKIP_JAX_PRECOMPILE=1 here)
+    # Server startup
     (vllm serve "$MODEL_NAME" \
         --tensor-parallel-size "$TP_SIZE" \
         --max-model-len="$MAX_MODEL_LEN" \
         --max-num-batched-tokens="$MAX_BATCHED_TOKENS" \
         2>&1 | tee -a "$LOG_FILE") &

-    # WAIT FOR SERVER (Shared Function Call)
-    waitForServerReady # Exits 1 on timeout
+    waitForServerReady

-    # Execute performance test
     run_performance_benchmark
     EXIT_CODE=$?
 fi

tpu_inference/models/jax/llama_guard_4.py (7 additions & 10 deletions)
@@ -1,15 +1,5 @@
 from tpu_inference.logger import init_logger

-logger = init_logger(__name__)
-
-# --- CRITICAL FIX: Add logger.warning() call here ---
-logger.warning(
-    "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
-    "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
-    "Multimodal inputs will fail.\n"
-    "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨"
-)
-
 import re
 from typing import List, Optional, Tuple, Any

@@ -35,6 +25,13 @@

 logger = init_logger(__name__)

+logger.warning(
+    "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+    "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+    "Multimodal inputs will fail.\n"
+    "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨"
+)
+
 class LlamaGuard4ForCausalLM(nnx.Module):

     def __init__(self,

tpu_inference/runner/tpu_runner.py (13 additions & 1 deletion)
@@ -508,7 +508,9 @@ def load_model(self):
         self.is_multimodal_model = (self.model_config.is_multimodal_model
                                     and self.get_multimodal_embeddings_fn
                                     is not None
-                                    and self.model_config.hf_config.architectures[0] != "Llama4ForConditionalGeneration" ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+                                    and hasattr(self.model_config.hf_config, "architectures") #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+                                    and len(self.model_config.hf_config.architectures) >= 1
+                                    and self.model_config.hf_config.architectures[0] != "Llama4ForConditionalGeneration" )

         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -695,13 +697,23 @@ def _execute_model(
             logits_indices_selector,
         ) = self._prepare_inputs(scheduler_output)

+        is_llama_guard_4 = ( hasattr(self.model_config.hf_config, "architectures") #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+                             and len(self.model_config.hf_config.architectures) >= 1
+                             and self.model_config.hf_config.architectures[0] == "Llama4ForConditionalGeneration" )
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             # We have the modality embeds at this time.
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(self.mm_manager.runner.requests[req_id].mm_features for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing via 'inputs_embeds' is not yet implemented."
+            )
         else:
             mm_embeds = []

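One design note on the new guard in tpu_runner.py: hf_config.architectures can be missing or empty on some configurations, which is why the hasattr and length checks now precede the indexing. A standalone sketch of the same defensive pattern is shown below; the helper name is illustrative, not part of the commit.

def uses_llama4_architecture(hf_config) -> bool:
    # Avoid AttributeError/IndexError when "architectures" is absent or empty.
    architectures = getattr(hf_config, "architectures", None) or []
    return bool(architectures) and architectures[0] == "Llama4ForConditionalGeneration"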
