diff --git a/nsight/extraction.py b/nsight/extraction.py index f715301..2eb6cbd 100644 --- a/nsight/extraction.py +++ b/nsight/extraction.py @@ -275,8 +275,9 @@ def extract_df_from_report( for arg_name, arg_values in arg_arrays.items(): df_data[arg_name] = arg_values - # Explode the dataframe - df = pd.DataFrame(df_data).apply(pd.Series.explode).reset_index(drop=True) + # Explode only Value and Metric columns (which contain tuples of per-metric data). + # Other columns (including function args) may also contain tuples that should NOT be exploded. + df = pd.DataFrame(df_data).explode(["Value", "Metric"]).reset_index(drop=True) if derive_metric is not None: transformed_df_data = { @@ -295,7 +296,7 @@ def extract_df_from_report( transformed_df = ( pd.DataFrame(transformed_df_data) - .apply(pd.Series.explode) + .explode(["Value", "Metric"]) .reset_index(drop=True) ) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 9ec81e3..566166e 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1231,3 +1231,89 @@ def profiled_func(x: int, y: int) -> None: assert all( df["AvgValue"].notna() & (df["AvgValue"] > 0) ), f"Invalid AvgValue for metric {metrics}" + + +# ============================================================================ +# Tuple-typed function arguments +# ============================================================================ + + +def test_tuple_typed_function_args() -> None: + """Test that tuple-typed function arguments are preserved and not exploded. + + Regression test: apply(pd.Series.explode) was exploding all columns, + including tuple-typed function arguments. When two tuple args had different + lengths, this caused a ValueError. Even with matching lengths, the data + was silently corrupted. + """ + + @nsight.analyze.kernel(output="quiet") + def kernel_with_tuple_args(size: int, mnkl: tuple, tile_shape: tuple) -> None: + a = torch.randn(size, size, device="cuda") + b = torch.randn(size, size, device="cuda") + with nsight.annotate("test_tuple_args"): + _ = a + b + + result = kernel_with_tuple_args( + configs=[(32, (8192, 8192, 8192, 1), (128, 256, 64))] + ) + df = result.to_dataframe() + + # Should have exactly 1 row (1 config, 1 default metric) + assert len(df) == 1, f"Expected 1 row, got {len(df)}" + + # Tuple args should be preserved as tuples, not exploded + assert df["mnkl"].iloc[0] == ( + 8192, + 8192, + 8192, + 1, + ), f"mnkl should be preserved as tuple, got {df['mnkl'].iloc[0]}" + assert df["tile_shape"].iloc[0] == ( + 128, + 256, + 64, + ), f"tile_shape should be preserved as tuple, got {df['tile_shape'].iloc[0]}" + + +def test_tuple_typed_function_args_multiple_metrics() -> None: + """Test that with multiple metrics, only Value/Metric are exploded, not tuple args. + + With 3 metrics and tuple args of different lengths (4 and 2 elements), + we should get exactly 3 rows per config — one per metric — with the + tuple args preserved intact on each row. + """ + metrics = [ + "gpu__time_duration.sum", + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + ] + + @nsight.analyze.kernel(output="quiet", metrics=metrics) + def kernel_with_tuple_args_multi(size: int, mnkl: tuple, pair: tuple) -> None: + a = torch.randn(size, size, device="cuda") + b = torch.randn(size, size, device="cuda") + with nsight.annotate("test_tuple_args_multi"): + _ = a + b + + result = kernel_with_tuple_args_multi( + configs=[(32, (8192, 8192, 8192, 1), (10, 20))] + ) + df = result.to_dataframe() + + # 3 metrics × 1 config = 3 rows + assert len(df) == 3, f"Expected 3 rows (one per metric), got {len(df)}" + + # Each row should have the correct metric + assert set(df["Metric"]) == set(metrics), ( + f"Expected metrics {metrics}, got {df['Metric'].tolist()}" + ) + + # Tuple args should be preserved as tuples on every row + for idx in range(len(df)): + assert df["mnkl"].iloc[idx] == (8192, 8192, 8192, 1), ( + f"Row {idx}: mnkl should be (8192, 8192, 8192, 1), got {df['mnkl'].iloc[idx]}" + ) + assert df["pair"].iloc[idx] == (10, 20), ( + f"Row {idx}: pair should be (10, 20), got {df['pair'].iloc[idx]}" + )