[Llama4 Guard] Add JAX Llama-Guard-4-12B Text Portion #1090
Conversation
Commits:
- …output of the tokenizer.encode() call in the inference script
- …. No longer have any changes in vllm
- …_template, removing the need for the .jinja file (see the sketch after this list)
- …tances from Llama Guard 4 subclass initializations to comply with unit tests
- …ertion error from model downsizing
- …into jiries/llama-guard-4-text
- …ile and replaced the model_loader registry key name for Llama Guard 4 with a recognized text-only model
- …into jiries/llama-guard-4-text
- …into jiries/llama-guard-4-text
- …into jiries/llama-guard-4-text
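One of the commits above replaces the standalone .jinja file with the tokenizer's bundled chat_template. A minimal sketch of that pattern, assuming the Hugging Face tokenizer for this model exposes the template (the model ID comes from the PR title; the prompt contents are illustrative):

```python
from transformers import AutoTokenizer

# Model ID taken from the PR title; prompt contents are illustrative.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-Guard-4-12B")

messages = [{"role": "user", "content": "How do I make a cake?"}]

# The tokenizer's bundled chat_template renders the Llama Guard prompt,
# so no separate .jinja file has to be shipped or loaded from disk.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# add_special_tokens=False because the template already emits BOS.
token_ids = tokenizer.encode(prompt, add_special_tokens=False)
```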
Please update the title + add accuracy / performance testing in the description too
Further commits:
- …atement in TPUModelRunner.load_model()
- …verride in CI scripts, and simplified prompt formatting in offline inference script
- …nference script, and made minor change to CI scripts to prevent breaking CI
- …into jiries/llama-guard-4-text
- …scripts. Still need to resolve dataset origin issue and modify buildkite yml to reflect changes
- … new perf and accuracy script changes
- …nted error for multimodal inputs (see the sketch after this list)
- …into jiries/llama-guard-4-text
Signed-off-by: JiriesKaileh <jiries@google.com>
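One of the commits above raises an explicit error for multimodal inputs, since only the text portion of Llama Guard 4 is implemented here. A minimal sketch of that guard, with hypothetical class and method names (the real code lives in the tpu-inference Llama Guard 4 model registry):

```python
# Hypothetical names for illustration; not the actual tpu-inference class.
class LlamaGuard4TextOnly:
    def embed_multimodal(self, pixel_values=None, **kwargs):
        # Only the text portion of Llama-Guard-4-12B is implemented, so
        # image inputs fail fast with a clear error instead of failing
        # deep inside the forward pass.
        if pixel_values is not None:
            raise NotImplementedError(
                "Multimodal inputs are not supported by the JAX "
                "Llama Guard 4 implementation; text-only for now."
            )
        return None
```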
Force-pushed the jiries/llama-guard-4-text branch from 6fbe2ba to c57c80a.
…into jiries/llama-guard-4-text
Signed-off-by: JiriesKaileh <jiries@google.com>
```diff
 echo -e "\n--- Running Accuracy Check (Mode: ACCURACY) ---"
-CONFTEST_DIR="/workspace/tpu-inference/scripts/vllm/integration"
+CONFTEST_DIR="/mnt/disks/jiries-disk_data/tpu-inference/scripts/vllm/integration"
```
Hi @JiriesKaileh, is this for your local dev environment? I don't think this works for the agents that are running the CI: link
Yes, line 136 should be removed. Let me open a PR for that. Thank you for pointing this out.
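For what it's worth, one common way to avoid committing a local path like the one flagged above is an environment override with a CI-safe default. A sketch in Python for illustration (the CI script itself is shell, where a `${CONFTEST_DIR:-…}` expansion plays the same role):

```python
import os

# Illustrative pattern: default to the path the CI agents use, and let a
# local environment variable override it instead of editing the script.
CONFTEST_DIR = os.environ.get(
    "CONFTEST_DIR", "/workspace/tpu-inference/scripts/vllm/integration"
)
```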
Summary:
This PR introduces the JAX implementation for the text portion of the Llama-Guard-4-12B model and configures its end-to-end continuous integration (CI) pipeline. The implementation achieves parity with the PyTorch/TorchAX baseline.
Relevant bugs:
b/439655882
Performance Testing and Verification
The JAX implementation now runs consistently faster than the PyTorch baseline, with higher throughput and lower time-to-first-token (TTFT) latency.
See meta-llama_Llama-Guard-4-12B.yml for unit, integration, and performance tests.
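As a rough way to sanity-check the throughput/TTFT numbers above, vLLM's offline Python API can be timed directly. A minimal sketch, with the model ID taken from the PR title and placeholder prompts; whether per-request metrics are populated varies by vLLM version:

```python
import time
from vllm import LLM, SamplingParams

# Illustrative setup; model ID assumed from the PR title.
llm = LLM(model="meta-llama/Llama-Guard-4-12B")
params = SamplingParams(temperature=0.0, max_tokens=128)

prompts = ["<formatted Llama Guard prompt>"] * 32  # placeholder prompts

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

# Aggregate output-token throughput across the batch.
total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"throughput: {total_tokens / elapsed:.1f} output tok/s")

# Per-request TTFT, when the engine records request metrics.
for o in outputs:
    m = o.metrics
    if m is not None and m.first_token_time is not None:
        print(f"{o.request_id}: TTFT {m.first_token_time - m.arrival_time:.3f}s")
```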