176 changes: 149 additions & 27 deletions aiu_fms_testing_utils/scripts/drive_paged_programs.py
@@ -15,13 +15,13 @@
from torch import distributed as dist
from torch.fx.experimental import _config as fx_config
from transformers import AutoTokenizer
import numpy as np

from aiu_fms_testing_utils.testing.validation import (
GoldenTokenHook,
LogitsExtractorHook,
capture_level_1_metrics,
extract_validation_information,
filter_failed_level_1_cases,
find_validation_info_path,
get_validation_info_path,
load_validation_information,
@@ -108,12 +108,26 @@
)

parser.add_argument(
"--cross_entropy_threshold",
"--default_cross_entropy_threshold",
type=float,
default=2.5,
help="threshold to denote passing/failing a given iteration",
)

parser.add_argument(
"--cross_entropy_threshold_path",
type=str,
default=None,
help="path to a file with all expected cross-entropy loss thresholds per program, pre sequence",
)

parser.add_argument(
"--per_sequence_failure_rate_threshold",
type=float,
default=0.1,
help="the threshold which denotes whether to pass or fail the test for a given sequence.",
)

parser.add_argument(
"--failure_rate_threshold",
type=float,
Expand Down Expand Up @@ -172,6 +186,12 @@
action="store_true",
help="set to true ensure that all prompts hit the same prompt program for a given test",
)
parser.add_argument(
"--generate_metrics_path",
type=str,
default=None,
help="if set, will bypass AIU model processing and generate cross-entropy loss thresholds used for testing, and save the metrics to the given path",
)

args = parser.parse_args()

@@ -180,9 +200,19 @@
model_variant = args.model_variant
DATASET_PATH = args.dataset_path
save_validation_info_outputs = args.save_validation_info_outputs
generate_metrics = args.generate_metrics_path is not None
tokenizer = AutoTokenizer.from_pretrained(model_variant)
custom_shape = None

default_cross_entropy_threshold = float(args.default_cross_entropy_threshold)
program_threshold_dict = {}
# if the path exists, load it as a json
if args.cross_entropy_threshold_path is not None and os.path.exists(
args.cross_entropy_threshold_path
):
with open(args.cross_entropy_threshold_path, "r") as f:
program_threshold_dict = json.load(f)
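For orientation, a minimal sketch of the structure json.load(f) is expected to yield here, mirroring the format the --generate_metrics_path branch writes out at the end of this script; every program id, sample key, and threshold value below is illustrative only.

# Hypothetical example only: keys are "<program_id>,<sample_key>" strings and each
# value maps a sentence index (stored as a string) to its cross-entropy threshold.
example_program_threshold_dict = {
    "0,sample_key_a": {"0": 2.31, "1": 2.12},
    "3,sample_key_b": {"0": 2.47, "1": 2.05},
}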

if args.dataset_type == "custom":
if local_rank == 0:
dprint(
Expand Down Expand Up @@ -332,13 +362,17 @@ def __load_validation_info(

distributed_kwargs = {}
if USE_DISTRIBUTED:
if generate_metrics:
torch.cuda.set_device(local_rank)
if args.dist_timeout > 0:
# Default timeout:
# https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
dist.init_process_group(
timeout=datetime.timedelta(minutes=args.dist_timeout), backend="gloo"
)
dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
else:
dist.init_process_group()
dist.init_process_group(backend="gloo")
aiu_dist_setup(dist.get_rank(), dist.get_world_size())
distributed_kwargs["distributed_strategy"] = "tp"
distributed_kwargs["group"] = dist.group.WORLD
@@ -349,7 +383,7 @@ def __load_validation_info(
with stagger_region(args.stagger_load):
model = get_model(
architecture="hf_pretrained",
device_type="cpu",
device_type="cuda" if generate_metrics else "cpu",
data_type=None if is_fp8 else torch.float16,
fused_weights=False,
**model_path_kwargs,
@@ -358,7 +392,8 @@

model.eval()
fx_config.backed_size_oblivious = True
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
if not generate_metrics:
model.compile(backend="sendnn", options={"sendnn.dynamic": True})

__maybe_prepare_fp8_weights(model, is_fp8)

@@ -391,14 +426,16 @@ def __load_validation_info(
and dist.get_world_size() == 4
):
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
warmup_model(
model,
input_ids,
max_new_tokens=max_new_tokens,
compile_dynamic_sendnn=True,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
**extra_kwargs,
)

if not generate_metrics:
warmup_model(
model,
input_ids,
max_new_tokens=max_new_tokens,
compile_dynamic_sendnn=True,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
**extra_kwargs,
)

if USE_DISTRIBUTED:
# wait for rank0 to be finished as it is the only one generating the criteria json
@@ -608,18 +645,19 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
# metric calculator based on the cross-entropy and mean diff for each decode step
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
cross_entropy = torch.nn.CrossEntropyLoss()(
r, t.softmax(dim=1).to(dtype=torch.float32)
r, t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
)
diff = torch.mean(
torch.abs(
r.softmax(dim=1).to(dtype=torch.float32)
- t.softmax(dim=1).to(dtype=torch.float32)
- t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
)
)
return (cross_entropy, diff)
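A rough, self-contained sketch of how this helper behaves; the shapes and values are made up, and in the script the validation utilities apply it once per decode step to a pair of per-step logit tensors rather than calling it directly like this.

# Illustrative only: two fake logit tensors of shape (num_tokens, vocab_size).
_reference_logits = torch.randn(4, 128)
_candidate_logits = torch.randn(4, 128)
_ce, _mean_diff = __metric_calculator(_reference_logits, _candidate_logits)
# _ce is the soft-target cross-entropy between the two distributions;
# _mean_diff is the mean absolute difference of their softmax probabilities.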


failed_cases = []
per_sentence_failed_cases = []
aggregate_failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
extra_kwargs["attn_name"] = ATTN_NAME
@@ -675,6 +713,14 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
)

if args.test_type == "metrics":
# if we are generating metrics, all inputs should be on cuda device
if generate_metrics:
input_ids = input_ids.to("cuda")
extra_kwargs = {
k: v.to("cuda") if isinstance(v, torch.Tensor) else v
for k, v in extra_kwargs.items()
}

aiu_validation_info = extract_validation_information(
model,
input_ids,
Expand Down Expand Up @@ -711,12 +757,70 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"'
)

ce_fail_responses = filter_failed_level_1_cases(
level_1_metrics, lambda m: m[0] >= args.cross_entropy_threshold
)
failure_rate = len(ce_fail_responses) / len(level_1_metrics)
if failure_rate >= args.failure_rate_threshold:
failed_cases.append((program_id, valid_prompt, failure_rate))
# if generating metrics, get the 99th percentile ce threshold per sentence
# otherwise test the thresholds
if generate_metrics:
sentence_ce_dict = {}
for sentence_idx, token_idx, metrics_value in level_1_metrics:
sentence_ce_dict.setdefault(sentence_idx, [])
sentence_ce_dict[sentence_idx].append(metrics_value[0])

sentence_ce_threshold = {
k: np.percentile(v, 99) for k, v in sentence_ce_dict.items()
}
if local_rank == 0:
dprint(
f"Program {str(program_id.program_id)} produced the following thresholds:\n{sentence_ce_threshold}"
)
program_threshold_dict[(str(program_id.program_id), sample_key)] = (
sentence_ce_threshold
)
else:
sentence_failures_dict = {}
for sentence_idx, token_idx, metrics_value in level_1_metrics:
program_threshold_key = f"{str(program_id.program_id)},{sample_key}"
if (
len(program_threshold_dict) != 0
and program_threshold_key not in program_threshold_dict
and local_rank == 0
):
dprint(
f"could not find the following key {program_threshold_key}, defaulting to {default_cross_entropy_threshold}"
)
ce_threshold = program_threshold_dict.get(
program_threshold_key,
{str(sentence_idx): default_cross_entropy_threshold},
)[str(sentence_idx)]
sentence_failures_dict.setdefault(sentence_idx, 0)
if metrics_value[0].item() >= ce_threshold:
sentence_failures_dict[sentence_idx] += 1

for sentence_idx, failure_count in sentence_failures_dict.items():
per_sentence_failure_rate = failure_count / max_new_tokens
if (
per_sentence_failure_rate
>= args.per_sequence_failure_rate_threshold
):
per_sentence_failed_cases.append(
(
program_id,
valid_prompt,
sentence_idx,
per_sentence_failure_rate,
)
)

aggregate_failure_rate = sum(sentence_failures_dict.values()) / (
max_new_tokens * len(sentence_failures_dict)
)
if aggregate_failure_rate >= args.failure_rate_threshold:
aggregate_failed_cases.append(
(
program_id,
valid_prompt,
aggregate_failure_rate,
)
)
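# Worked example with illustrative numbers: with max_new_tokens = 128 and the default
# per_sequence_failure_rate_threshold of 0.1, a sentence is flagged once 13 of its
# decode steps exceed their threshold (13 / 128 ≈ 0.102), while 12 failing steps
# (12 / 128 ≈ 0.094) still pass; the aggregate check applies the same ratio across
# all sentences of the prompt against failure_rate_threshold.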

elif args.test_type == "tokens":
aiu_validation_info = extract_validation_information(
Expand Down Expand Up @@ -784,12 +888,30 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")

if generate_metrics and local_rank == 0:
with open(args.generate_metrics_path, "w") as f:
json_dict = {}
for program_seq, sentence_ce_threshold_dict in program_threshold_dict.items():
program_seq_key = ",".join(program_seq)
json_dict[program_seq_key] = {}
for sentence_i, ce_threshold in sentence_ce_threshold_dict.items():
json_dict[program_seq_key][sentence_i] = float(ce_threshold)

json.dump(json_dict, f, indent=4)

if not args.skip_validation and local_rank == 0:
if len(failed_cases) != 0:
dprint("the test failed with the following cases:")
for failed_case in failed_cases:
if len(aggregate_failed_cases) != 0:
dprint("the test failed with the following aggregate cases:")
for failed_case in aggregate_failed_cases:
dprint(
f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
)
else:
if len(per_sentence_failed_cases) != 0:
dprint("the test failed with the following per sentence cases:")
for failed_case in per_sentence_failed_cases:
dprint(
f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Sentence Index: {failed_case[2]}, Failure Rate: {failed_case[3]}"
)

if len(aggregate_failed_cases) == 0 and len(per_sentence_failed_cases) == 0:
dprint("all tests passed")
1 change: 1 addition & 0 deletions aiu_fms_testing_utils/testing/validation.py
@@ -218,6 +218,7 @@ def load_validation_information(
f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
)

validation_files_paths.sort(key=lambda p: int(p.name.split(".pt")[0]))
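# e.g. file names like "0.pt", "1.pt", ..., "10.pt": a plain lexicographic sort would
# place "10.pt" before "2.pt", so sort on the integer prefix to load files in numeric order.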
validation_info = []
for i, validation_file_path in enumerate(validation_files_paths):
if i == batch_size:
30 changes: 19 additions & 11 deletions aiu_fms_testing_utils/utils/paged.py
@@ -189,10 +189,10 @@ def generate(
(
torch.zeros(
NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype
),
).to(input_ids.device),
torch.zeros(
NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype
),
).to(input_ids.device),
)
for _ in range(model.config.nlayers)
]
@@ -304,9 +304,9 @@ def generate(
last_n_tokens = kwargs.get("last_n_tokens", 0)
output, current_kv_cache = model(
input_ids_i,
slot_mapping=slot_mapping_i,
position_ids=position_ids_i,
mask=mask_i,
slot_mapping=slot_mapping_i.to(input_ids.device),
position_ids=position_ids_i.to(input_ids.device),
mask=mask_i.to(input_ids.device),
past_key_value_states=current_kv_cache,
use_cache=kwargs["use_cache"],
last_n_tokens=last_n_tokens,
@@ -342,8 +342,10 @@ def generate(
# mask is no longer used here
kwargs["mask"] = None
kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
kwargs["position_ids"] = kwargs["position_ids"].clone(
memory_format=torch.contiguous_format
kwargs["position_ids"] = (
kwargs["position_ids"]
.clone(memory_format=torch.contiguous_format)
.to(device=input_ids.device)
)
kwargs["last_n_tokens"] = 1

@@ -371,14 +373,20 @@ def generate(
for b_seq in block_table
],
dtype=torch.int64,
).to(device=input_ids.device)
kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask.to(
device=input_ids.device
)
kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
current_tkv_mask = current_tkv_mask + 1
kwargs["current_tkv_mask"] = current_tkv_mask
kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
kwargs["current_tkv_mask"] = current_tkv_mask.to(device=input_ids.device)
kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64).to(
device=input_ids.device
)

# batch
input_ids = input_ids.clone(memory_format=torch.contiguous_format)
input_ids = input_ids.clone(memory_format=torch.contiguous_format).to(
device=input_ids.device
)
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(kwargs["block_table"], 0)
torch._dynamo.mark_dynamic(kwargs["slot_mapping"], 0)
2 changes: 1 addition & 1 deletion tests/testing/test_validation.py
@@ -37,7 +37,7 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook):
model.reset_parameters()

seq_length = 64
batch_size = 8
batch_size = 16
max_new_tokens = 128

# prepare input_ids