From 8e99ab777a7b065d40cb861f6a470cd323f6a115 Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Tue, 11 Nov 2025 14:39:45 -0500 Subject: [PATCH 1/6] First refactor Signed-off-by: Rafael Vasquez --- .../scripts/drive_paged_programs.py | 174 +++++++++--------- 1 file changed, 92 insertions(+), 82 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index c10652b..6a69ba1 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -169,6 +169,11 @@ action="store_true", help="set to true to save cpu validation outputs for later consumption", ) +parser.add_argument( + "--stop_after_info_outputs", + action="store_true", + help="set to true to stop after cpu validation outputs have been saved", +) parser.add_argument( "--prioritize_large_batch_sizes", action="store_true", @@ -373,55 +378,6 @@ def __load_validation_info( ) model.eval() -fx_config.backed_size_oblivious = True -model.compile(backend="sendnn", options={"sendnn.dynamic": True}) - -__maybe_prepare_fp8_weights(model, is_fp8) - -if not args.skip_validation: - with stagger_region(args.stagger_load): - validation_model = get_model( - architecture="hf_pretrained", - device_type="cpu", - data_type=None if is_fp8 else torch.float32, - fused_weights=False, - **model_path_kwargs, - **distributed_kwargs, - ) - validation_model.eval() - -# warmup with any input so compiler produces criteria json -# TODO: Swap this with __prepare_inputs once fix for shape_id is available -# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) -prompt_list = [torch.arange(0, 64, dtype=torch.int64)] -# matching vllm warmup to pad to 2 on fp8, and no pad for fp16 -if is_fp8: - prompt_list = prompt_list * 2 -input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) -extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - -extra_kwargs["attn_name"] = ATTN_NAME -if ( - "granite-3.3-8b-instruct" in model_variant - and USE_DISTRIBUTED - and dist.get_world_size() == 4 -): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT -warmup_model( - model, - input_ids, - max_new_tokens=max_new_tokens, - compile_dynamic_sendnn=True, - stagger_update_lazyhandle=args.stagger_update_lazyhandle, - prefill_chunk_size=args.prefill_chunk_size, - **extra_kwargs, -) - -if USE_DISTRIBUTED: - # wait for rank0 to be finished as it is the only one generating the criteria json - # this is needed since otherwise we may run into a race condition - torch.distributed.barrier() - @dataclass class ProgramInfo: @@ -431,7 +387,6 @@ class ProgramInfo: prompt_length_limit: int prompt_length_limit_type: str - def parse_program_limit(limit_str: str) -> tuple[int, str]: matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)") @@ -448,7 +403,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: limit_val = int(match.group(2)) return limit_val, limit_type - +# TODO: Add a check or logic for case that prog criteria json must exist if saving CPU outputs with open(args.program_criteria_json_path, "r") as f: program_criteria_json_list = json.load(f)["programs"] program_criteria_list = [] @@ -621,39 +576,20 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" ) - -# metric calculator based on the cross-entropy and mean diff for each decode step -def __metric_calculator(r: torch.Tensor, t: torch.Tensor): - cross_entropy = torch.nn.CrossEntropyLoss()( - r, t.softmax(dim=1).to(dtype=torch.float32) - ) - diff = torch.mean( - torch.abs( - r.softmax(dim=1).to(dtype=torch.float32) - - t.softmax(dim=1).to(dtype=torch.float32) - ) - ) - return (cross_entropy, diff) - - -failed_cases = [] -# for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: - extra_kwargs["attn_name"] = ATTN_NAME - if ( - "granite-3.3-8b-instruct" in model_variant - and USE_DISTRIBUTED - and dist.get_world_size() == 4 - ): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT - - if local_rank == 0: - dprint(f"*** testing program {program_id} ***") - dprint( - f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" +if not args.skip_validation: + with stagger_region(args.stagger_load): + validation_model = get_model( + architecture="hf_pretrained", + device_type="cpu", + data_type=None if is_fp8 else torch.float32, + fused_weights=False, + **model_path_kwargs, + **distributed_kwargs, ) + validation_model.eval() - if not args.skip_validation: + for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: + dprint(f"Working on program_id: {program_id}") # attempt to load the cpu validation info if it is already computed cpu_validation_info = __load_validation_info( model_variant, @@ -691,6 +627,80 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): ) ) +if args.stop_after_info_outputs: + dprint("CPU validation outputs saved. Exiting as requested.") + exit(0) + +################################################################ + +fx_config.backed_size_oblivious = True +model.compile(backend="sendnn", options={"sendnn.dynamic": True}) +__maybe_prepare_fp8_weights(model, is_fp8) + +# warmup with any input so compiler produces criteria json +# TODO: Swap this with __prepare_inputs once fix for shape_id is available +# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) +prompt_list = [torch.arange(0, 64, dtype=torch.int64)] +# matching vllm warmup to pad to 2 on fp8, and no pad for fp16 +if is_fp8: + prompt_list = prompt_list * 2 +input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) +extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + +extra_kwargs["attn_name"] = ATTN_NAME +if ( + "granite-3.3-8b-instruct" in model_variant + and USE_DISTRIBUTED + and dist.get_world_size() == 4 +): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT +warmup_model( + model, + input_ids, + max_new_tokens=max_new_tokens, + compile_dynamic_sendnn=True, + stagger_update_lazyhandle=args.stagger_update_lazyhandle, + prefill_chunk_size=args.prefill_chunk_size, + **extra_kwargs, +) + +if USE_DISTRIBUTED: + # wait for rank0 to be finished as it is the only one generating the criteria json + # this is needed since otherwise we may run into a race condition + torch.distributed.barrier() + +# metric calculator based on the cross-entropy and mean diff for each decode step +def __metric_calculator(r: torch.Tensor, t: torch.Tensor): + cross_entropy = torch.nn.CrossEntropyLoss()( + r, t.softmax(dim=1).to(dtype=torch.float32) + ) + diff = torch.mean( + torch.abs( + r.softmax(dim=1).to(dtype=torch.float32) + - t.softmax(dim=1).to(dtype=torch.float32) + ) + ) + return (cross_entropy, diff) + + +failed_cases = [] +# for each program and valid prompt (batch size, sequence length) +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: + extra_kwargs["attn_name"] = ATTN_NAME + if ( + "granite-3.3-8b-instruct" in model_variant + and USE_DISTRIBUTED + and dist.get_world_size() == 4 + ): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT + + if local_rank == 0: + dprint(f"*** testing program {program_id} ***") + dprint( + f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" + ) + + if not args.skip_validation: if args.test_type == "metrics": aiu_validation_info = extract_validation_information( model, From 555d03a1a871e07516cd7e9dd4f083e0df503c8d Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Wed, 12 Nov 2025 17:41:37 -0500 Subject: [PATCH 2/6] Revert original dpp, add refactored dpp Signed-off-by: Rafael Vasquez --- .../scripts/drive_paged_programs.py | 176 ++-- .../scripts/refactored_dpp.py | 944 ++++++++++++++++++ 2 files changed, 1028 insertions(+), 92 deletions(-) create mode 100644 aiu_fms_testing_utils/scripts/refactored_dpp.py diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 6a69ba1..557d994 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -169,11 +169,6 @@ action="store_true", help="set to true to save cpu validation outputs for later consumption", ) -parser.add_argument( - "--stop_after_info_outputs", - action="store_true", - help="set to true to stop after cpu validation outputs have been saved", -) parser.add_argument( "--prioritize_large_batch_sizes", action="store_true", @@ -367,6 +362,7 @@ def __load_validation_info( dist.get_rank() == 0 ) +# not validation with stagger_region(args.stagger_load): model = get_model( architecture="hf_pretrained", @@ -378,6 +374,56 @@ def __load_validation_info( ) model.eval() +fx_config.backed_size_oblivious = True +model.compile(backend="sendnn", options={"sendnn.dynamic": True}) + +__maybe_prepare_fp8_weights(model, is_fp8) + +# is validation +if not args.skip_validation: + with stagger_region(args.stagger_load): + validation_model = get_model( + architecture="hf_pretrained", + device_type="cpu", + data_type=None if is_fp8 else torch.float32, + fused_weights=False, + **model_path_kwargs, + **distributed_kwargs, + ) + validation_model.eval() + +# warmup with any input so compiler produces criteria json +# TODO: Swap this with __prepare_inputs once fix for shape_id is available +# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) +prompt_list = [torch.arange(0, 64, dtype=torch.int64)] +# matching vllm warmup to pad to 2 on fp8, and no pad for fp16 +if is_fp8: + prompt_list = prompt_list * 2 +input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) +extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + +extra_kwargs["attn_name"] = ATTN_NAME +if ( + "granite-3.3-8b-instruct" in model_variant + and USE_DISTRIBUTED + and dist.get_world_size() == 4 +): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT +warmup_model( + model, + input_ids, + max_new_tokens=max_new_tokens, + compile_dynamic_sendnn=True, + stagger_update_lazyhandle=args.stagger_update_lazyhandle, + prefill_chunk_size=args.prefill_chunk_size, + **extra_kwargs, +) + +if USE_DISTRIBUTED: + # wait for rank0 to be finished as it is the only one generating the criteria json + # this is needed since otherwise we may run into a race condition + torch.distributed.barrier() + @dataclass class ProgramInfo: @@ -387,6 +433,7 @@ class ProgramInfo: prompt_length_limit: int prompt_length_limit_type: str + def parse_program_limit(limit_str: str) -> tuple[int, str]: matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)") @@ -403,7 +450,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: limit_val = int(match.group(2)) return limit_val, limit_type -# TODO: Add a check or logic for case that prog criteria json must exist if saving CPU outputs + with open(args.program_criteria_json_path, "r") as f: program_criteria_json_list = json.load(f)["programs"] program_criteria_list = [] @@ -576,20 +623,39 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" ) -if not args.skip_validation: - with stagger_region(args.stagger_load): - validation_model = get_model( - architecture="hf_pretrained", - device_type="cpu", - data_type=None if is_fp8 else torch.float32, - fused_weights=False, - **model_path_kwargs, - **distributed_kwargs, + +# metric calculator based on the cross-entropy and mean diff for each decode step +def __metric_calculator(r: torch.Tensor, t: torch.Tensor): + cross_entropy = torch.nn.CrossEntropyLoss()( + r, t.softmax(dim=1).to(dtype=torch.float32) + ) + diff = torch.mean( + torch.abs( + r.softmax(dim=1).to(dtype=torch.float32) + - t.softmax(dim=1).to(dtype=torch.float32) ) - validation_model.eval() + ) + return (cross_entropy, diff) - for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: - dprint(f"Working on program_id: {program_id}") + +failed_cases = [] +# for each program and valid prompt (batch size, sequence length) +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: + extra_kwargs["attn_name"] = ATTN_NAME + if ( + "granite-3.3-8b-instruct" in model_variant + and USE_DISTRIBUTED + and dist.get_world_size() == 4 + ): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT + + if local_rank == 0: + dprint(f"*** testing program {program_id} ***") + dprint( + f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" + ) + + if not args.skip_validation: # attempt to load the cpu validation info if it is already computed cpu_validation_info = __load_validation_info( model_variant, @@ -627,80 +693,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: ) ) -if args.stop_after_info_outputs: - dprint("CPU validation outputs saved. Exiting as requested.") - exit(0) - -################################################################ - -fx_config.backed_size_oblivious = True -model.compile(backend="sendnn", options={"sendnn.dynamic": True}) -__maybe_prepare_fp8_weights(model, is_fp8) - -# warmup with any input so compiler produces criteria json -# TODO: Swap this with __prepare_inputs once fix for shape_id is available -# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) -prompt_list = [torch.arange(0, 64, dtype=torch.int64)] -# matching vllm warmup to pad to 2 on fp8, and no pad for fp16 -if is_fp8: - prompt_list = prompt_list * 2 -input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) -extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - -extra_kwargs["attn_name"] = ATTN_NAME -if ( - "granite-3.3-8b-instruct" in model_variant - and USE_DISTRIBUTED - and dist.get_world_size() == 4 -): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT -warmup_model( - model, - input_ids, - max_new_tokens=max_new_tokens, - compile_dynamic_sendnn=True, - stagger_update_lazyhandle=args.stagger_update_lazyhandle, - prefill_chunk_size=args.prefill_chunk_size, - **extra_kwargs, -) - -if USE_DISTRIBUTED: - # wait for rank0 to be finished as it is the only one generating the criteria json - # this is needed since otherwise we may run into a race condition - torch.distributed.barrier() - -# metric calculator based on the cross-entropy and mean diff for each decode step -def __metric_calculator(r: torch.Tensor, t: torch.Tensor): - cross_entropy = torch.nn.CrossEntropyLoss()( - r, t.softmax(dim=1).to(dtype=torch.float32) - ) - diff = torch.mean( - torch.abs( - r.softmax(dim=1).to(dtype=torch.float32) - - t.softmax(dim=1).to(dtype=torch.float32) - ) - ) - return (cross_entropy, diff) - - -failed_cases = [] -# for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: - extra_kwargs["attn_name"] = ATTN_NAME - if ( - "granite-3.3-8b-instruct" in model_variant - and USE_DISTRIBUTED - and dist.get_world_size() == 4 - ): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT - - if local_rank == 0: - dprint(f"*** testing program {program_id} ***") - dprint( - f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" - ) - - if not args.skip_validation: if args.test_type == "metrics": aiu_validation_info = extract_validation_information( model, diff --git a/aiu_fms_testing_utils/scripts/refactored_dpp.py b/aiu_fms_testing_utils/scripts/refactored_dpp.py new file mode 100644 index 0000000..766edb7 --- /dev/null +++ b/aiu_fms_testing_utils/scripts/refactored_dpp.py @@ -0,0 +1,944 @@ +import argparse +from dataclasses import dataclass +import datetime +import itertools +import json +import os +from pathlib import Path +import random +import time +from itertools import dropwhile +import re +from typing import Any, Dict, List, Optional, Tuple + +import torch +from fms.models import get_model +from fms.utils.generation import pad_input_ids +from torch import distributed as dist +from torch.fx.experimental import _config as fx_config +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.validation import ( + GoldenTokenHook, + LogitsExtractorHook, + ValidationInfo, + capture_level_1_metrics, + extract_validation_information, + filter_failed_level_1_cases, + find_validation_info_path, + get_validation_info_path, + load_validation_information, + top_k_loss_calculator, +) +from aiu_fms_testing_utils.utils import ( + get_pad_size, + sample_rag_factoid_requests, + sample_sharegpt_requests, + stagger_region, + warmup_model, +) +from aiu_fms_testing_utils.utils.aiu_setup import aiu_dist_setup, dprint, local_rank +from aiu_fms_testing_utils.utils.paged import ( + ProgramCriteria, + get_programs_prompts, + KVCACHE_NUM_BLOCKS_HINT, +) +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string + + +@dataclass +class ProgramInfo: + program_id: str + batch_size_limit: int + batch_size_limit_type: str + prompt_length_limit: int + prompt_length_limit_type: str + +def parse_cli_args() -> argparse.Namespace: + """ + Initializes the argument parser and parses command-line arguments. + + Returns: + argparse.Namespace: An object containing the parsed arguments. + """ + + parser = argparse.ArgumentParser( + description="Script which will drive paged programs for debugging" + ) + + parser.add_argument( + "--programs", + metavar="N", + type=str, + nargs="*", + default=[], + help=""" + The list of programs to run. This would take a list where each element would be one of program_id OR :,. + If program_id is specified any prompt that would result in this program would be selected. + If :, is specified, then with the given program_id, select a prompt that satisfies min_batch and min_prompt_length (if none exists, a message will be printed to warn the user) + If this list is empty, each program will be run once with any prompt that would result in this program being selected. + """, + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=8, + help="set this if you want to change the number of tokens generated per sequence (1 prefill + max_new_tokens-1 decodes). Note: If this value is larger than 64, this may result in switching decode programs mid generation", + ) + parser.add_argument( + "--distributed", + action="store_true", + help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", + ) + parser.add_argument( + "--model_variant", + type=str, + default="ibm-ai-platform/micro-g3.3-8b-instruct-1b", + help="The model id or path to use for this test. Note: must be a huggingface format", + ) + parser.add_argument( + "--timing", + type=str, + choices=["e2e", "per-token"], + default="", + help="if set, how to time the generation of tokens, e2e or per-token", + ) + parser.add_argument( + "--program_criteria_json_path", + type=str, + help="path to json file containing the program criteria list", + ) + parser.add_argument( + "--dataset_path", + type=str, + help="path to dataset", + ) + parser.add_argument( + "--dataset_type", + type=str, + choices=["rag_factoid", "sharegpt", "custom"], + default="sharegpt", + help="selects the correct dataset type for sampling. Must be one of rag_factoid or sharegpt", + ) + parser.add_argument( + "--test_type", + type=str, + choices=["tokens", "metrics"], + default="metrics", + help="set the type of the test that you would like to run. If metrics, will inject tokens and get metrics. If tokens, will not inject tokens and get tokens", + ) + + parser.add_argument( + "--cross_entropy_threshold", + type=float, + default=2.5, + help="threshold to denote passing/failing a given iteration", + ) + + parser.add_argument( + "--failure_rate_threshold", + type=float, + default=0.1, + help="the threshold which denotes whether to pass or fail the test. The failure threshold is defined as the number of failing iterations (cross_entropy) over the total iterations. If this value exceeds the failure_rate_threshold, we will fail the test", + ) + + parser.add_argument( + "--attention_type", + type=str, + default="paged", + choices=["paged", "paged_fp8"], + help="The attention type to use", + ) + parser.add_argument( + "--prefill_chunk_size", + type=int, + default=0, + help="if > 0, activate chunked prefill, with chunk_size=this_argument. Only works with paged attention variants.", + ) + parser.add_argument( + "--stagger_load", + type=int, + default=0, + help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes", + ) + parser.add_argument( + "--stagger_update_lazyhandle", + type=int, + default=0, + help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes", + ) + parser.add_argument( + "--dist_timeout", + type=int, + default=0, + help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group", + ) + parser.add_argument( + "--skip_validation", + action="store_true", + help="set to true to skip cpu validation", + ) + parser.add_argument( + "--validation_info_outputs_dir", + type=str, + default="/home/senuser/models/validation_info", + help="path to directory containing validation info outputs", + ) + parser.add_argument( + "--save_validation_info_outputs", + action="store_true", + help="set to true to save cpu validation outputs for later consumption", + ) + parser.add_argument( + "--prioritize_large_batch_sizes", + action="store_true", + help="set to true if you would like to prioritize large batch sizes", + ) + parser.add_argument( + "--enforce_homogeneous_prompt_programs", + action="store_true", + help="set to true ensure that all prompts hit the same prompt program for a given test", + ) + + return parser.parse_args() + + +def __prepare_inputs( + batch_size: int, + seq_length: int, + tokenizer: AutoTokenizer, + sampler, + dataset_path: str, + allow_truncation: bool, + enforce_sizes: List[int] = [], + seed: int = 0, +): + start = time.time() + prompts_and_sizes, sample_key = sampler( + dataset_path, + batch_size, + tokenizer, + 32, + seq_length * 2 if allow_truncation else seq_length, + seed, + enforce_sizes=enforce_sizes, + truncation=allow_truncation, + return_key=True, + ) + end = time.time() + if local_rank == 0: + dprint(f"extracted prompts in {(end - start):.4f} seconds") + prompt_list = [] + for prompt, size in prompts_and_sizes: + encoded = tokenizer.encode(prompt, return_tensors="pt").squeeze(0) + if size > seq_length: + assert allow_truncation + encoded = encoded[:seq_length] + prompt_list.append(encoded) + + if len(prompt_list) < batch_size: + dprint( + f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt." + ) + prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list + + input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) + extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + return input_ids, extra_kwargs, sample_key + + +def __maybe_prepare_fp8_weights(model: torch.nn.Module, is_fp8: bool): + if is_fp8: + for name, param in model.named_parameters(): + if param.dtype == torch.bfloat16: + if param.max() > torch.finfo(torch.float16).max: + dprint( + f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended." + ) + param.data = param.data.to(dtype=torch.float16) + + +def __load_validation_info( + model_variant, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + seed, + cpu_dtype: str, + attn_type: str, + validation_info_outputs_dir: str, + sample_key: str | None = None, +): + full_path = find_validation_info_path( + validation_info_dir=validation_info_outputs_dir, + model_variant=model_variant, + batch_size=batch_size, + seq_length=seq_length, + max_new_tokens=max_new_tokens, + seed=seed, + attn_type=attn_type, + version_allow_decrement=True, + dtype=cpu_dtype, + sample_key=sample_key, + ) + if full_path is not None: + dprint(f"cpu validation info found for seed={seed} -- loading it") + return load_validation_information(full_path, "logits", batch_size, tokenizer) + else: + return None + +def parse_program_limit(limit_str: str) -> tuple[int, str]: + matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)") + + # Default limit to min to maintain backwards compat + try: + limit_type = ">=" + limit_val = int(limit_str) + except ValueError: + limit_type = None + match = matcher.fullmatch(limit_str) + if match is None: + raise ValueError("Program not well formatted, wrong limit type") + limit_type = match.group(1) + limit_val = int(match.group(2)) + return limit_val, limit_type + +# metric calculator based on the cross-entropy and mean diff for each decode step +def __metric_calculator(r: torch.Tensor, t: torch.Tensor): + cross_entropy = torch.nn.CrossEntropyLoss()( + r, t.softmax(dim=1).to(dtype=torch.float32) + ) + diff = torch.mean( + torch.abs( + r.softmax(dim=1).to(dtype=torch.float32) + - t.softmax(dim=1).to(dtype=torch.float32) + ) + ) + return (cross_entropy, diff) + + +def get_model_path_kwargs(model_variant: str) -> Dict[str, Any]: + + model_path_kwargs = {} + if os.path.exists(model_variant): + model_path_kwargs["model_path"] = model_variant + else: + model_path_kwargs["variant"] = model_variant + + return model_path_kwargs + +def get_distributed_kwargs(is_distributed: bool, dist_timeout: str) -> Dict[str, Any]: + + distributed_kwargs = {} + if is_distributed: + if dist_timeout > 0: + # Default timeout: + # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group + dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout)) + dprint(f"NOTICE: init_process_group timeout set to {dist_timeout} minutes") + else: + dist.init_process_group() + + aiu_dist_setup(dist.get_rank(), dist.get_world_size()) + distributed_kwargs["distributed_strategy"] = "tp" + distributed_kwargs["group"] = dist.group.WORLD + save_validation_info_outputs = save_validation_info_outputs and ( + dist.get_rank() == 0 + ) + + return distributed_kwargs + +def get_sampler(dataset_type: str, dataset_path:str, tokenizer: AutoTokenizer): + + custom_shape = None + if dataset_type == "custom": + if local_rank == 0: + dprint( + "Using custom prompts from user, programs parameter will be ignored as it will be determined by user prompt" + ) + directory = Path(dataset_path) + if not directory.is_dir(): + dprint("when using a custom dataset, you must provide a directory") + exit() + + result = [] + for fp in directory.iterdir(): + if fp.is_file(): + try: + content = fp.read_text() + result.append((content, get_pad_size(len(tokenizer.encode(content))))) + except Exception as e: + print(f"Error while reading {fp} for custom dataset: {e}") + exit() + + custom_shape = (len(result), max([_[1] for _ in result])) + + def __custom_line_sampler(*args, **kwargs): + return_key = kwargs.get("return_key", False) + sample_key = format_kwargs_to_string(**kwargs) + if return_key: + return result, sample_key + return result + + sampler = __custom_line_sampler + allow_truncation = False + elif dataset_type == "rag_factoid": + sampler = sample_rag_factoid_requests + allow_truncation = False + elif dataset_type == "sharegpt": + sampler = sample_sharegpt_requests + allow_truncation = True + else: + raise ValueError("dataset_type must be one of rag_factoid or sharegpt") + + return sampler, allow_truncation, custom_shape + +def load_model( + device_type: str, + model_variant: str, + is_fp8: bool, + distributed_kwargs: Dict[str, Any], + stagger_load: int, + is_validation: bool = False, +): + + device_type = "cpu" + dtype = None if is_fp8 else (torch.float32 if is_validation else torch.float16) + + model_path_kwargs = get_model_path_kwargs(model_variant) + + with stagger_region(stagger_load): + model = get_model( + architecture="hf_pretrained", + device_type=device_type, + data_type=dtype, + fused_weights=False, + **model_path_kwargs, + **distributed_kwargs, + ) + + model.eval() + + # Compile if it's not the validation model + if not is_validation: + fx_config.backed_size_oblivious = True + model.compile(backend="sendnn", options={"sendnn.dynamic": True}) + + return model + +def get_programs_to_test(programs, program_criteria_list): + + programs_to_test = [] + for program_str in programs: + enforce_prompt_split = program_str.split(":") + program_id = enforce_prompt_split[0] + if len(enforce_prompt_split) == 1: + programs_to_test.append( + ProgramInfo(program_id, 0, ">=", 0, ">=") + ) # this will always satisfy + else: + enforce_batch_size, enforce_prompt_length = ( + _ for _ in enforce_prompt_split[1].split(",") + ) + + # Default limit to min to maintain backwards compat + enforce_batch_size_val, enforce_batch_size_type = parse_program_limit( + enforce_batch_size + ) + enforce_prompt_length_val, enforce_prompt_length_type = parse_program_limit( + enforce_prompt_length + ) + + programs_to_test.append( + ProgramInfo( + program_id, + enforce_batch_size_val, + enforce_batch_size_type, + enforce_prompt_length_val, + enforce_prompt_length_type, + ) + ) + + if len(programs_to_test) == 0: + programs_to_test = [ + ProgramInfo(str(p.program_id), 0, ">=", 0, ">=") + for p in program_criteria_list + ] + + return programs_to_test + + +def get_program_prompt_list( + program_map, + dataset_path: str, + enforce_homogeneous_prompt_programs: bool, + programs_to_test: List[ProgramInfo], + program_criteria_list: List[ProgramCriteria], + tokenizer: AutoTokenizer, + sampler, + allow_truncation: bool, + custom_shape: Optional[Tuple[int, int]], +): + + # select prompts that fit the batch size criteria + valid_prompts = [] + if custom_shape: + for program_criteria_seq, valid_prompt_shapes in program_map.items(): + for valid_prompt_shape in valid_prompt_shapes: + if valid_prompt_shape == custom_shape: + enforce_sizes = [valid_prompt_shape[1]] + input_ids, extra_kwargs, sample_key = __prepare_inputs( + batch_size=valid_prompt_shape[0], + seq_length=valid_prompt_shape[1], + tokenizer=tokenizer, + sampler=sampler, + dataset_path=dataset_path, + allow_truncation=allow_truncation, + enforce_sizes=enforce_sizes, + ) + valid_prompts = [ + ( + program_criteria_seq[0].program_id, + custom_shape, + input_ids, + extra_kwargs, + sample_key, + ) + ] + break + if len(valid_prompts) > 0: + break + else: + for program_info in programs_to_test: + program_id = program_info.program_id + + filtered_program_map = program_map + if program_id.isnumeric(): + filtered_program_map = { + k: v + for k, v in program_map.items() + if k[0] == program_criteria_list[int(program_id)] + } + used_keys = set() + # for each program, we need to check if we have a shape that satisfies the --programs request + for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): + # if ? or numeric => we need to check if we have found at least one valid key to stop + if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: + break + # if * => we need to see if we have found the first key to see if we should skip + elif program_id == "*" and program_seq_key[0] in used_keys: + continue + + for valid_prompt_shape in valid_prompt_shapes: + # make sure the criteria for batch limit and prompt limit is satisfied + # eval is safe here because we have limited what type and limit can be before + + batch_check = eval( + f"valid_prompt_shape[0] {program_info.batch_size_limit_type} {program_info.batch_size_limit}" + ) + prompt_check = eval( + f"valid_prompt_shape[1] {program_info.prompt_length_limit_type} {program_info.prompt_length_limit}" + ) + if batch_check and prompt_check: + # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length + # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning + # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user + enforce_sizes = [valid_prompt_shape[1]] + if enforce_homogeneous_prompt_programs: + # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) + possible_seq_lengths = [ + _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) + ] + # favor sequences that are close to the valid prompt length + possible_seq_lengths.reverse() + enforce_sizes = enforce_sizes + list( + itertools.islice( + itertools.cycle(possible_seq_lengths), + valid_prompt_shape[0] - 1, + ) + ) + try: + input_ids, extra_kwargs, sample_key = __prepare_inputs( + batch_size=valid_prompt_shape[0], + seq_length=valid_prompt_shape[1], + tokenizer=tokenizer, + sampler=sampler, + dataset_path=dataset_path, + allow_truncation=allow_truncation, + enforce_sizes=enforce_sizes, + ) + valid_prompts.append( + ( + program_seq_key[0], + valid_prompt_shape, + input_ids, + extra_kwargs, + sample_key, + ) + ) + used_keys.add(program_seq_key[0]) + break + except ValueError: + dprint( + f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" + ) + + if len(used_keys) == 0 and local_rank == 0: + dprint( + f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{program_info.batch_size_limit_type}{program_info.batch_size_limit} and prompt_length{program_info.prompt_length_limit_type}{program_info.prompt_length_limit}" + ) + + return valid_prompts + + +def run_validation_tests( + args: argparse.Namespace, + model: torch.nn.Module, + validation_model: Optional[torch.nn.Module], + program_id: int, + valid_prompt, + input_ids: torch.Tensor, + extra_kwargs: Dict[str, Any], + sample_key: str, + attn_name: str, + cpu_dtype: str, + tokenizer: AutoTokenizer, +) -> None: + + if local_rank == 0: + dprint(f"*** testing program {program_id} ***") + dprint( + f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" + ) + + cpu_validation_info: Optional[ValidationInfo] = None + if not args.skip_validation: + # attempt to load the cpu validation info if it is already computed + cpu_validation_info = __load_validation_info( + model_variant=args.model_variant, + batch_size=valid_prompt[0], + seq_length=valid_prompt[1], + max_new_tokens=args.max_new_tokens, + tokenizer=tokenizer, + seed=0, + cpu_dtype=cpu_dtype, + attn_type=attn_name, + validation_info_outputs_dir=args.validation_info_outputs_dir, + sample_key=sample_key, + ) + # if the cpu validation info is not yet computed, compute it + if cpu_validation_info is None and validation_model is not None: + cpu_validation_info = extract_validation_information( + validation_model, + input_ids, + args.max_new_tokens, + LogitsExtractorHook(), + attn_algorithm="math", + **extra_kwargs, + ) + # save the cpu validation info if requested + if args.save_validation_info_outputs: + cpu_validation_info.save( + get_validation_info_path( + validation_info_dir=args.validation_info_outputs_dir, + model_variant=args.model_variant, + batch_size=valid_prompt[0], + seq_length=valid_prompt[1], + max_new_tokens=args.max_new_tokens, + seed=0, + attn_type=attn_name, + dtype=cpu_dtype, + sample_key=sample_key, + ) + ) + + golden_hook = None + if args.test_type == "metrics": + if not args.skip_validation and cpu_validation_info: + golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens")) + + aiu_validation_info = extract_validation_information( + model, + input_ids, + args.max_new_tokens, + golden_hook, + last_n_tokens=64, + timing=args.timing, + prefill_chunk_size=args.prefill_chunk_size, + **extra_kwargs, + ) + + if args.test_type == "metrics": + process_metrics_test ( + cross_entropy_threshold=args.cross_entropy_threshold, + failure_rate_threshold=args.failure_rate_threshold, + aiu_validation_info=aiu_validation_info, + cpu_validation_info=cpu_validation_info, + program_id=program_id, + prompt_shape=valid_prompt, + tokenizer=tokenizer + ) + + elif args.test_type == "tokens": + process_tokens_test ( + max_new_tokens=args.max_new_tokens, + aiu_validation_info=aiu_validation_info, + cpu_validation_info=cpu_validation_info, + program_id=program_id, + tokenizer=tokenizer + ) + + if args.skip_validation and local_rank == 0: + for sentence_idx, test_sentence in enumerate( + aiu_validation_info.get_info("tokens") + ): + tokens_prompt = [t.item() for t in test_sentence[:-args.max_new_tokens]] + aiu_tokens_generated = [t.item() for t in test_sentence[-args.max_new_tokens:]] + dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") + dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}") + dprint(f"AIU tokens:\n{aiu_tokens_generated}") + dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") + + +def process_metrics_test( + cross_entropy_threshold: float, + failure_rate_threshold: float, + aiu_validation_info: ValidationInfo, + cpu_validation_info: ValidationInfo, + program_id: str, + prompt_shape: Tuple[int, int], + tokenizer: AutoTokenizer, +) -> None: + + level_1_metrics = capture_level_1_metrics( + cpu_validation_info.get_info("logits"), + aiu_validation_info.get_info("logits"), + top_k_loss_calculator(20, __metric_calculator), + ) + + if local_rank == 0: + cpu_tokens = cpu_validation_info.get_info("tokens") + for sentence_idx, token_idx, metrics_value in level_1_metrics: + aiu_token = torch.argmax( + aiu_validation_info.get_info("logits")[sentence_idx][token_idx], dim=-1 + ) + cpu_token = cpu_tokens[sentence_idx][prompt_shape[1] + token_idx] + aiu_str = tokenizer.decode(aiu_token).replace( + "\n", "" + ) # remove newlines for readability + cpu_str = tokenizer.decode(cpu_token).replace( + "\n", "" + ) # remove newlines for readability + dprint( + f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"' + ) + + ce_fail_responses = filter_failed_level_1_cases( + level_1_metrics, lambda m: m[0] >= cross_entropy_threshold + ) + + failure_rate = len(ce_fail_responses) / len(level_1_metrics) if level_1_metrics else 0.0 + + if failure_rate >= failure_rate_threshold: + dprint(f"[FAIL] Program {program_id} failed with rate {failure_rate:.4f} >= threshold {failure_rate_threshold}.") + + if local_rank == 0: + dprint(f"[PASS] Program {program_id} passed. Failure Rate: {failure_rate:.4f}.") + + +def process_tokens_test( + max_new_tokens: int, + aiu_validation_info: ValidationInfo, + cpu_validation_info: ValidationInfo, + program_id: str, + tokenizer: AutoTokenizer, +) -> None: + + if local_rank != 0: + return + + for sentence_idx, (reference_sentence, test_sentence) in enumerate( + zip( + cpu_validation_info.get_info("tokens"), + aiu_validation_info.get_info("tokens"), + ) + ): + tokens_prompt = [ + t.item() for t in reference_sentence[:-max_new_tokens] + ] + cpu_tokens_generated = [ + t.item() for t in reference_sentence[-max_new_tokens:] + ] + aiu_tokens_generated = [ + t.item() for t in test_sentence[-max_new_tokens:] + ] + tokens_prompt_without_pad = list( + dropwhile(lambda x: x == tokenizer.pad_token_id, tokens_prompt) + ) + prompt_length = len( + [token_id for token_id in tokens_prompt_without_pad] + ) + dprint(f"Prompt Length: {prompt_length}") + dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") + dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt_without_pad)}") + dprint(f"CPU tokens:\n{cpu_tokens_generated}") + dprint(f"AIU tokens:\n{aiu_tokens_generated}") + dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}") + dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") + + +## ENV SETUP ## + +args = parse_cli_args() + +if args.skip_validation and args.test_type == "metrics": + dprint("When skipping validation, only test_type will be ignored") + +attention_map = { + "sdpa": "sdpa_causal", + "paged": "spyre_paged_attn", + "math_fp8": "math_fp8", + "paged_fp8": "spyre_paged_attn_fp8", +} +ATTN_NAME = attention_map[args.attention_type] + +is_fp8 = "fp8" in args.attention_type +CPU_DTYPE = "fp8" if is_fp8 else "fp32" + +torch.manual_seed(42) +torch.set_grad_enabled(False) + +os.environ["COMPILATION_MODE"] = "offline_decoder" +os.environ["DT_PROG_CRITERIA_FILEPATH"] = args.program_criteria_json_path +if ("VLLM_DT_MAX_CONTEXT_LEN" not in os.environ or "VLLM_DT_MAX_BATCH_SIZE" not in os.environ): + if local_rank == 0: + dprint( + "Please specify VLLM_DT_MAX_CONTEXT_LEN and VLLM_DT_MAX_BATCH_SIZE environment variables" + ) + exit() +max_batch_size = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"]) +max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]) + + +## MODEL LOADING ## + +# Get distributed kwargs (empty if not distributed) +distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout) + +model = load_model( + device_type="cpu", + model_variant=args.model_variant, + is_fp8=is_fp8, + distributed_kwargs=distributed_kwargs, + stagger_load=args.stagger_load, + is_validation=False) + +__maybe_prepare_fp8_weights(model, is_fp8) + +# Load validation model +validation_model = None +if not args.skip_validation: + validation_model = load_model( + args.model_variant, is_fp8, distributed_kwargs, args.stagger_load, is_validation=True + ) + +## MODEL WARMUP ## + +# warmup with any input so compiler produces criteria json +# TODO: Swap this with __prepare_inputs once fix for shape_id is available +# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) +prompt_list = [torch.arange(0, 64, dtype=torch.int64)] +# matching vllm warmup to pad to 2 on fp8, and no pad for fp16 +if is_fp8: + prompt_list = prompt_list * 2 +input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) +extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + +extra_kwargs["attn_name"] = ATTN_NAME +if ( "granite-3.3-8b-instruct" in args.model_variant and args.distributed and dist.get_world_size() == 4): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT + +warmup_model( + model=model, + input_ids=input_ids, + max_new_tokens=args.max_new_tokens, + compile_dynamic_sendnn=True, + stagger_update_lazyhandle=args.stagger_update_lazyhandle, + prefill_chunk_size=args.prefill_chunk_size, + **extra_kwargs, +) + +if args.distributed: + # wait for rank0 to be finished as it is the only one generating the criteria json + # this is needed since otherwise we may run into a race condition + torch.distributed.barrier() + + +## PREPARE PROGRAM CRITERIA AND PROMPTS ## + +with open(args.program_criteria_json_path, "r") as f: + program_criteria_json_list = json.load(f)["programs"] + program_criteria_list = [] + for i, d in enumerate(program_criteria_json_list): + program_criteria_list.append( + ProgramCriteria( + i, + d["max_batch"], + d["max_tkv"], + d["batch_granularity"], + d["tkv_granularity"], + ) + ) + + programs_to_test = get_programs_to_test(args.programs, program_criteria_list) + + +# FIXME: filter condition for this on prompt and batch +program_map = get_programs_prompts( + program_criteria_list=program_criteria_list, + multiple=64, + max_batch_size=max_batch_size, + max_tkv=max_tkv, + program_cycles=args.max_new_tokens, + prioritize_large_batch_sizes=args.prioritize_large_batch_sizes, +) +for v in program_map.values(): + random.Random(42).shuffle(v) + +tokenizer = AutoTokenizer.from_pretrained(args.model_variant) +sampler, allow_truncation, custom_shape = get_sampler(args.dataset_type, args.dataset_path, tokenizer) + +# Select concrete prompts and program associations +valid_prompts = get_program_prompt_list( + args, + programs_to_test, + program_criteria_list, + program_map, + tokenizer, + sampler, + allow_truncation, + custom_shape +) + +## RUN TESTS ## + +failed_cases = [] +# for each program and valid prompt (batch size, sequence length) +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: + extra_kwargs["attn_name"] = ATTN_NAME + + run_validation_tests( + args=args, + model=model, + validation_model=validation_model, + program_id=program_id, + valid_prompt=valid_prompt, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + sample_key=sample_key, + attn_name=ATTN_NAME, + cpu_dtype=CPU_DTYPE, + tokenizer=tokenizer, + ) \ No newline at end of file From 46fefc1c7c8e30a7636ae89e3aea37034331e9e0 Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Wed, 12 Nov 2025 17:42:39 -0500 Subject: [PATCH 3/6] Small cleanup Signed-off-by: Rafael Vasquez --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 2 -- aiu_fms_testing_utils/scripts/refactored_dpp.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 557d994..c10652b 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -362,7 +362,6 @@ def __load_validation_info( dist.get_rank() == 0 ) -# not validation with stagger_region(args.stagger_load): model = get_model( architecture="hf_pretrained", @@ -379,7 +378,6 @@ def __load_validation_info( __maybe_prepare_fp8_weights(model, is_fp8) -# is validation if not args.skip_validation: with stagger_region(args.stagger_load): validation_model = get_model( diff --git a/aiu_fms_testing_utils/scripts/refactored_dpp.py b/aiu_fms_testing_utils/scripts/refactored_dpp.py index 766edb7..d5ed40b 100644 --- a/aiu_fms_testing_utils/scripts/refactored_dpp.py +++ b/aiu_fms_testing_utils/scripts/refactored_dpp.py @@ -941,4 +941,4 @@ def process_tokens_test( attn_name=ATTN_NAME, cpu_dtype=CPU_DTYPE, tokenizer=tokenizer, - ) \ No newline at end of file + ) From 0f47515bef2ab883372319de6bb0e8fa65817e74 Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Thu, 13 Nov 2025 11:33:50 -0500 Subject: [PATCH 4/6] Fix a few vars and args Signed-off-by: Rafael Vasquez --- .../scripts/refactored_dpp.py | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/refactored_dpp.py b/aiu_fms_testing_utils/scripts/refactored_dpp.py index d5ed40b..03e87fc 100644 --- a/aiu_fms_testing_utils/scripts/refactored_dpp.py +++ b/aiu_fms_testing_utils/scripts/refactored_dpp.py @@ -328,7 +328,7 @@ def get_model_path_kwargs(model_variant: str) -> Dict[str, Any]: return model_path_kwargs -def get_distributed_kwargs(is_distributed: bool, dist_timeout: str) -> Dict[str, Any]: +def get_distributed_kwargs(is_distributed: bool, dist_timeout: str, save_validation_info_outputs: bool) -> Dict[str, Any]: distributed_kwargs = {} if is_distributed: @@ -403,7 +403,6 @@ def load_model( is_validation: bool = False, ): - device_type = "cpu" dtype = None if is_fp8 else (torch.float32 if is_validation else torch.float16) model_path_kwargs = get_model_path_kwargs(model_variant) @@ -631,10 +630,10 @@ def run_validation_tests( # if the cpu validation info is not yet computed, compute it if cpu_validation_info is None and validation_model is not None: cpu_validation_info = extract_validation_information( - validation_model, - input_ids, - args.max_new_tokens, - LogitsExtractorHook(), + model=validation_model, + input_ids=input_ids, + max_new_tokens=args.max_new_tokens, + post_iteration_hook=LogitsExtractorHook(), attn_algorithm="math", **extra_kwargs, ) @@ -660,10 +659,10 @@ def run_validation_tests( golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens")) aiu_validation_info = extract_validation_information( - model, - input_ids, - args.max_new_tokens, - golden_hook, + model=model, + input_ids=input_ids, + max_new_tokens=args.max_new_tokens, + post_iteration_hook=golden_hook, last_n_tokens=64, timing=args.timing, prefill_chunk_size=args.prefill_chunk_size, @@ -825,7 +824,7 @@ def process_tokens_test( ## MODEL LOADING ## # Get distributed kwargs (empty if not distributed) -distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout) +distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout, args.save_validation_info_outputs) model = load_model( device_type="cpu", @@ -841,7 +840,12 @@ def process_tokens_test( validation_model = None if not args.skip_validation: validation_model = load_model( - args.model_variant, is_fp8, distributed_kwargs, args.stagger_load, is_validation=True + device_type="cpu", + model_variant=args.model_variant, + is_fp8=is_fp8, + distributed_kwargs=distributed_kwargs, + stagger_load=args.stagger_load, + is_validation=True ) ## MODEL WARMUP ## @@ -912,14 +916,15 @@ def process_tokens_test( # Select concrete prompts and program associations valid_prompts = get_program_prompt_list( - args, - programs_to_test, - program_criteria_list, - program_map, - tokenizer, - sampler, - allow_truncation, - custom_shape + program_map=program_map, + dataset_path=args.dataset_path, + enforce_homogeneous_prompt_programs=args.enforce_homogeneous_prompt_programs, + programs_to_test=programs_to_test, + program_criteria_list=program_criteria_list, + tokenizer=tokenizer, + sampler=sampler, + allow_truncation=allow_truncation, + custom_shape=custom_shape, ) ## RUN TESTS ## From 340c383aaba4cbb46d4bcbb8a6df35a622f566cb Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Thu, 13 Nov 2025 12:56:24 -0500 Subject: [PATCH 5/6] Refactor of validation and tests Signed-off-by: Rafael Vasquez --- .../scripts/refactored_dpp.py | 111 ++++++++++-------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/refactored_dpp.py b/aiu_fms_testing_utils/scripts/refactored_dpp.py index 03e87fc..581f0ba 100644 --- a/aiu_fms_testing_utils/scripts/refactored_dpp.py +++ b/aiu_fms_testing_utils/scripts/refactored_dpp.py @@ -592,7 +592,7 @@ def get_program_prompt_list( return valid_prompts -def run_validation_tests( +def run_validation( args: argparse.Namespace, model: torch.nn.Module, validation_model: Optional[torch.nn.Module], @@ -604,7 +604,7 @@ def run_validation_tests( attn_name: str, cpu_dtype: str, tokenizer: AutoTokenizer, -) -> None: +): if local_rank == 0: dprint(f"*** testing program {program_id} ***") @@ -669,47 +669,17 @@ def run_validation_tests( **extra_kwargs, ) - if args.test_type == "metrics": - process_metrics_test ( - cross_entropy_threshold=args.cross_entropy_threshold, - failure_rate_threshold=args.failure_rate_threshold, - aiu_validation_info=aiu_validation_info, - cpu_validation_info=cpu_validation_info, - program_id=program_id, - prompt_shape=valid_prompt, - tokenizer=tokenizer - ) - - elif args.test_type == "tokens": - process_tokens_test ( - max_new_tokens=args.max_new_tokens, - aiu_validation_info=aiu_validation_info, - cpu_validation_info=cpu_validation_info, - program_id=program_id, - tokenizer=tokenizer - ) - - if args.skip_validation and local_rank == 0: - for sentence_idx, test_sentence in enumerate( - aiu_validation_info.get_info("tokens") - ): - tokens_prompt = [t.item() for t in test_sentence[:-args.max_new_tokens]] - aiu_tokens_generated = [t.item() for t in test_sentence[-args.max_new_tokens:]] - dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") - dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}") - dprint(f"AIU tokens:\n{aiu_tokens_generated}") - dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") + return aiu_validation_info, cpu_validation_info -def process_metrics_test( +def run_metrics_test( cross_entropy_threshold: float, - failure_rate_threshold: float, aiu_validation_info: ValidationInfo, cpu_validation_info: ValidationInfo, program_id: str, prompt_shape: Tuple[int, int], tokenizer: AutoTokenizer, -) -> None: +): level_1_metrics = capture_level_1_metrics( cpu_validation_info.get_info("logits"), @@ -737,17 +707,11 @@ def process_metrics_test( ce_fail_responses = filter_failed_level_1_cases( level_1_metrics, lambda m: m[0] >= cross_entropy_threshold ) + failure_rate = len(ce_fail_responses) / len(level_1_metrics) - failure_rate = len(ce_fail_responses) / len(level_1_metrics) if level_1_metrics else 0.0 - - if failure_rate >= failure_rate_threshold: - dprint(f"[FAIL] Program {program_id} failed with rate {failure_rate:.4f} >= threshold {failure_rate_threshold}.") - - if local_rank == 0: - dprint(f"[PASS] Program {program_id} passed. Failure Rate: {failure_rate:.4f}.") - + return failure_rate -def process_tokens_test( +def run_tokens_test( max_new_tokens: int, aiu_validation_info: ValidationInfo, cpu_validation_info: ValidationInfo, @@ -827,10 +791,10 @@ def process_tokens_test( distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout, args.save_validation_info_outputs) model = load_model( - device_type="cpu", - model_variant=args.model_variant, - is_fp8=is_fp8, - distributed_kwargs=distributed_kwargs, + device_type="cpu", + model_variant=args.model_variant, + is_fp8=is_fp8, + distributed_kwargs=distributed_kwargs, stagger_load=args.stagger_load, is_validation=False) @@ -927,14 +891,14 @@ def process_tokens_test( custom_shape=custom_shape, ) -## RUN TESTS ## +## RUN VALIDATION AND TESTS ## failed_cases = [] # for each program and valid prompt (batch size, sequence length) for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: extra_kwargs["attn_name"] = ATTN_NAME - run_validation_tests( + aiu_validation_info, cpu_validation_info = run_validation( args=args, model=model, validation_model=validation_model, @@ -947,3 +911,50 @@ def process_tokens_test( cpu_dtype=CPU_DTYPE, tokenizer=tokenizer, ) + + if args.test_type == "metrics": + failure_rate = run_metrics_test ( + cross_entropy_threshold=args.cross_entropy_threshold, + aiu_validation_info=aiu_validation_info, + cpu_validation_info=cpu_validation_info, + program_id=program_id, + prompt_shape=valid_prompt, + tokenizer=tokenizer + ) + if failure_rate > args.failure_rate_threshold: + failed_cases.append( + (program_id, valid_prompt, failure_rate) + ) + + elif args.test_type == "tokens": + run_tokens_test ( + max_new_tokens=args.max_new_tokens, + aiu_validation_info=aiu_validation_info, + cpu_validation_info=cpu_validation_info, + program_id=program_id, + tokenizer=tokenizer + ) + + else: + raise ValueError("test type must be one of metrics or tokens") + + if args.skip_validation and local_rank == 0: + for sentence_idx, test_sentence in enumerate( + aiu_validation_info.get_info("tokens") + ): + tokens_prompt = [t.item() for t in test_sentence[:-args.max_new_tokens]] + aiu_tokens_generated = [t.item() for t in test_sentence[-args.max_new_tokens:]] + dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") + dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}") + dprint(f"AIU tokens:\n{aiu_tokens_generated}") + dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") + +if not args.skip_validation and local_rank == 0: + if len(failed_cases) != 0: + dprint("The test failed with the following cases:") + for failed_case in failed_cases: + dprint( + f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}" + ) + else: + dprint("all tests passed") From 258e03d5b9bcc67ddc7b59097bf3173f695849cc Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Fri, 14 Nov 2025 16:35:42 -0500 Subject: [PATCH 6/6] More cleanup Signed-off-by: Rafael Vasquez --- aiu_fms_testing_utils/.DS_Store | Bin 0 -> 8196 bytes .../scripts/refactored_dpp.py | 75 +++++++++++------- 2 files changed, 47 insertions(+), 28 deletions(-) create mode 100644 aiu_fms_testing_utils/.DS_Store diff --git a/aiu_fms_testing_utils/.DS_Store b/aiu_fms_testing_utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6e14e5a57708bde95781d7a3d6f4cc65953af53f GIT binary patch literal 8196 zcmeHMU2GIp6u#fI(3yeC6lf_sLPnbK*MiGm>5t0NAD{vO+3gPnTz7W{m@u6wJG0v& zAy$bmf@kJjD`l!)!=gtCc;h`@E!`x)x*G1+2l-qd>Lj3U=JJ=}R9d%_uUy>r47M7sg2 zo#xs3nQl3LxjXN0x2hT{ZhcpAOHt?P1B2?pthTUBbKf@_*q-YTSVf;`R$Jhck9`Ht z-Cwl0e6Zj-r4iT9YITL2Qz$vPz}oIOY5Qr*9}2ATfWHmDgtg zhqV16o4Vs_i>FN;%EkHh%T}yxPHySg*0XER{>j=07HDp&N}2OKSB4 zdPdg1mT8r2w`>!zY0(9~<)+kHu|6WQO{rwGTrykA*3;p%iydYbCC*uPp1r{?u=m*~ z>p&GSV01b6WU=>zlEt;_zDQrg9NtP68 zp|nVM_*;O`4aQ%h-%6@16KqinTU|?kk(`Kwg zE4HEy9e4!2*nyqsBfw_Rj{yu}nBZy>V2{B@5k3L}?K7Cbvp9?A39>Ka6}*bqZ~<=; zc;AV@`yoD>2FJ&z!V#Bd!m*q#x{l`_BS{6sQ;lM7HW8>2%8mT~-+TA(|Fh{b){P?& zN8o=J0aW*-d-&;pw@p Dict[str, Any]: +def get_model_kwargs(model_variant: str) -> Dict[str, Any]: model_path_kwargs = {} if os.path.exists(model_variant): @@ -328,6 +328,7 @@ def get_model_path_kwargs(model_variant: str) -> Dict[str, Any]: return model_path_kwargs + def get_distributed_kwargs(is_distributed: bool, dist_timeout: str, save_validation_info_outputs: bool) -> Dict[str, Any]: distributed_kwargs = {} @@ -349,6 +350,7 @@ def get_distributed_kwargs(is_distributed: bool, dist_timeout: str, save_validat return distributed_kwargs + def get_sampler(dataset_type: str, dataset_path:str, tokenizer: AutoTokenizer): custom_shape = None @@ -394,10 +396,11 @@ def __custom_line_sampler(*args, **kwargs): return sampler, allow_truncation, custom_shape + def load_model( device_type: str, - model_variant: str, is_fp8: bool, + model_kwargs: Dict[str, Any], distributed_kwargs: Dict[str, Any], stagger_load: int, is_validation: bool = False, @@ -405,15 +408,13 @@ def load_model( dtype = None if is_fp8 else (torch.float32 if is_validation else torch.float16) - model_path_kwargs = get_model_path_kwargs(model_variant) - with stagger_region(stagger_load): model = get_model( architecture="hf_pretrained", device_type=device_type, data_type=dtype, fused_weights=False, - **model_path_kwargs, + **model_kwargs, **distributed_kwargs, ) @@ -423,9 +424,11 @@ def load_model( if not is_validation: fx_config.backed_size_oblivious = True model.compile(backend="sendnn", options={"sendnn.dynamic": True}) + __maybe_prepare_fp8_weights(model, is_fp8) return model + def get_programs_to_test(programs, program_criteria_list): programs_to_test = [] @@ -468,7 +471,7 @@ def get_programs_to_test(programs, program_criteria_list): return programs_to_test -def get_program_prompt_list( +def get_valid_prompts( program_map, dataset_path: str, enforce_homogeneous_prompt_programs: bool, @@ -591,12 +594,9 @@ def get_program_prompt_list( return valid_prompts - -def run_validation( +def generate_cpu_validation( args: argparse.Namespace, - model: torch.nn.Module, validation_model: Optional[torch.nn.Module], - program_id: int, valid_prompt, input_ids: torch.Tensor, extra_kwargs: Dict[str, Any], @@ -606,12 +606,6 @@ def run_validation( tokenizer: AutoTokenizer, ): - if local_rank == 0: - dprint(f"*** testing program {program_id} ***") - dprint( - f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" - ) - cpu_validation_info: Optional[ValidationInfo] = None if not args.skip_validation: # attempt to load the cpu validation info if it is already computed @@ -653,6 +647,16 @@ def run_validation( ) ) + return cpu_validation_info + +def generate_aiu_validation( + args: argparse.Namespace, + model: torch.nn.Module, + input_ids: torch.Tensor, + cpu_validation_info: Optional[ValidationInfo], + extra_kwargs: Dict[str, Any], +): + golden_hook = None if args.test_type == "metrics": if not args.skip_validation and cpu_validation_info: @@ -669,7 +673,7 @@ def run_validation( **extra_kwargs, ) - return aiu_validation_info, cpu_validation_info + return aiu_validation_info def run_metrics_test( @@ -787,26 +791,27 @@ def run_tokens_test( ## MODEL LOADING ## -# Get distributed kwargs (empty if not distributed) +model_kwargs = get_model_kwargs(args.model_variant) + +# distributed_kwargs is empty if not distributed distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout, args.save_validation_info_outputs) model = load_model( device_type="cpu", - model_variant=args.model_variant, is_fp8=is_fp8, + model_kwargs=model_kwargs, distributed_kwargs=distributed_kwargs, stagger_load=args.stagger_load, - is_validation=False) + is_validation=False + ) -__maybe_prepare_fp8_weights(model, is_fp8) -# Load validation model validation_model = None if not args.skip_validation: validation_model = load_model( device_type="cpu", - model_variant=args.model_variant, is_fp8=is_fp8, + model_kwargs=model_kwargs, distributed_kwargs=distributed_kwargs, stagger_load=args.stagger_load, is_validation=True @@ -879,7 +884,7 @@ def run_tokens_test( sampler, allow_truncation, custom_shape = get_sampler(args.dataset_type, args.dataset_path, tokenizer) # Select concrete prompts and program associations -valid_prompts = get_program_prompt_list( +valid_prompts = get_valid_prompts( program_map=program_map, dataset_path=args.dataset_path, enforce_homogeneous_prompt_programs=args.enforce_homogeneous_prompt_programs, @@ -896,13 +901,19 @@ def run_tokens_test( failed_cases = [] # for each program and valid prompt (batch size, sequence length) for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: + + if local_rank == 0: + dprint(f"*** testing program {program_id} ***") + dprint( + f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}" + ) + extra_kwargs["attn_name"] = ATTN_NAME - aiu_validation_info, cpu_validation_info = run_validation( + # Returns none if skipping validation + cpu_validation_info = generate_cpu_validation( args=args, - model=model, - validation_model=validation_model, - program_id=program_id, + validation_model=validation_model, # could be None if skipping validation valid_prompt=valid_prompt, input_ids=input_ids, extra_kwargs=extra_kwargs, @@ -911,6 +922,14 @@ def run_tokens_test( cpu_dtype=CPU_DTYPE, tokenizer=tokenizer, ) + + aiu_validation_info = generate_aiu_validation( + args=args, + model=model, + input_ids=input_ids, + cpu_validation_info=cpu_validation_info, + extra_kwargs=extra_kwargs, + ) if args.test_type == "metrics": failure_rate = run_metrics_test (