
Conversation

@JiriesKaileh (Collaborator) commented Nov 13, 2025

Summary:
This PR introduces the JAX implementation for the text portion of the Llama-Guard-4-12B model and configures its end-to-end continuous integration (CI) pipeline. The implementation achieves parity with the PyTorch/TorchAX baseline.

Relevant bugs:

b/439655882

Performance Testing and Verification

The JAX implementation now runs consistently faster than the PyTorch baseline, with higher output throughput and lower TTFT latency.

| Metric | JAX Run (Final) | TorchAX Run (Baseline) | Analysis |
| --- | --- | --- | --- |
| Output Throughput (tok/s) | 573.00 | 563.03 | JAX is 1.02x faster. |
| Mean TTFT (ms) | 1,429.90 | 1,480.30 | JAX is faster; initial latency is consistently lower. |
| Mean TPOT (ms) | 89.18 | 88.49 | Parity; token generation speed is aligned between the two backends. |
| Benchmark Duration (s) | 2.62 | 2.66 | JAX is faster; total execution time is shorter. |
| Accuracy | 31.43% | 31.43% | Parity; accuracies are identical for the same subset of prompts. |
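
For reference, TTFT, TPOT, and output throughput in serving benchmarks like the one above are typically related as sketched below. This is an illustrative sketch of the standard definitions, not the PR's benchmark code, and the numbers are made up (not taken from this run):

```python
# Illustrative sketch of common serving-benchmark metric definitions.
# Not the actual benchmark code from this PR; numbers below are made up.

def tpot_ms(e2e_latency_ms: float, ttft_ms: float, output_tokens: int) -> float:
    """Time per output token: decode time spread over tokens after the first."""
    return (e2e_latency_ms - ttft_ms) / (output_tokens - 1)

def output_throughput(total_output_tokens: int, duration_s: float) -> float:
    """Aggregate output tokens per second over the whole benchmark run."""
    return total_output_tokens / duration_s

print(tpot_ms(10000.0, 1000.0, 101))              # 90.0 (ms per token)
print(round(output_throughput(1500, 2.62), 2))    # 572.52 (tok/s)
```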

See meta-llama_Llama-Guard-4-12B.yml for the unit, integration, and performance tests.

…output of the tokenizer.encode() call in the inference script
…_template, removing the need for the .jinja file
…tances from Llama Guard 4 subclass initializations to comply with unit tests
…ile and replaced the model_loader registry key name for Llama Guard 4 with a recognized text only model
@github-actions

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description should include relevant details and context, for example:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@jrplatin (Collaborator)

Please update the title + add accuracy / performance testing in the description too

@JiriesKaileh JiriesKaileh changed the title Jiries/llama guard 4 text feat(llama_guard): Add JAX Llama-Guard-4-12B Text Portion and Achieve Torchax Performance/Accuracy Parity Nov 14, 2025
@jrplatin jrplatin changed the title feat(llama_guard): Add JAX Llama-Guard-4-12B Text Portion and Achieve Torchax Performance/Accuracy Parity [Llama4 Guard]: Add JAX Llama-Guard-4-12B Text Portion Nov 14, 2025
@jrplatin jrplatin changed the title [Llama4 Guard]: Add JAX Llama-Guard-4-12B Text Portion [Llama4 Guard] Add JAX Llama-Guard-4-12B Text Portion Nov 14, 2025
…verride in CI scripts, and simplified prompt formatting in offline inference script
…nference script, and made minor change to CI scripts to prevent breaking CI
…scripts. Still need to resolve dataset origin issue and modify buildkite yml to reflect changes
Signed-off-by: JiriesKaileh <jiries@google.com>
@JiriesKaileh JiriesKaileh force-pushed the jiries/llama-guard-4-text branch from 6fbe2ba to c57c80a Compare November 20, 2025 03:24
Signed-off-by: JiriesKaileh <jiries@google.com>
@JiriesKaileh JiriesKaileh merged commit 3a9e2d4 into main Nov 20, 2025
2 of 3 checks passed
echo -e "\n--- Running Accuracy Check (Mode: ACCURACY) ---"

CONFTEST_DIR="/workspace/tpu-inference/scripts/vllm/integration"
CONFTEST_DIR="/mnt/disks/jiries-disk_data/tpu-inference/scripts/vllm/integration"
Collaborator

Hi @JiriesKaileh, is this for your local dev environment? I don't think this works for the agents that are running the CI: link

@JiriesKaileh (Collaborator, Author) Nov 20, 2025

Yes, line 136 should be removed. Let me open a PR for that. Thank you for pointing this out.
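
A common way to avoid hardcoding a local dev path like this in CI scripts is an environment-variable override with a CI-safe default. A minimal sketch, where the `CONFTEST_DIR_OVERRIDE` variable name is hypothetical and not from this PR:

```shell
# Use the CI workspace path unless the developer explicitly overrides it.
# CONFTEST_DIR_OVERRIDE is a hypothetical variable name for illustration.
CONFTEST_DIR="${CONFTEST_DIR_OVERRIDE:-/workspace/tpu-inference/scripts/vllm/integration}"
echo "Using conftest dir: $CONFTEST_DIR"
```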

