3 changes: 3 additions & 0 deletions .gitignore
@@ -96,3 +96,6 @@ experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.in
!experimental/grouped_convolution_tile_instances/instances/*.inc
experimental/grouped_convolution_tile_instances/*.inc

# AI agent config
.claude
243 changes: 243 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,243 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Composable Kernel (CK) is AMD's high-performance GPU kernel library for machine learning workloads. It uses HIP C++ with a tile-based programming model and tensor coordinate transformation techniques.

**Two implementations exist:**
- **CK Tile** (`include/ck_tile/`) - Modern tile-programming API, preferred for new development
- **Legacy CK** (`include/ck/`) - Older implementation, still supported

## Build Commands

```bash
# Development build (from build directory)
mkdir build && cd build
../script/cmake-ck-dev.sh .. "gfx908;gfx90a;gfx942"
make -j32 # Use ~2GB RAM per thread

# Build specific targets
make tile_example_fmha_fwd # FMHA forward example
make tile_example_fmha_bwd # FMHA backward example
make ckProfiler # Performance profiler

# Tests
make smoke # Quick tests (<30s each)
make regression # Long tests (>=30s each)
make check # All tests

# Single test
ctest -R "fmha" -V
```

**CMake options:**
- `GPU_TARGETS="gfx908;gfx90a;gfx942"` - Target GPU architectures
- `DTYPES="fp32;fp16;fp8;bf16;int8"` - Data types to build
- `BUILD_DEV=ON` - Development mode

## Code Formatting

CK uses clang-format. Install pre-commit hooks:
```bash
sudo script/install_precommit.sh
```

Disable temporarily with `git commit --no-verify`.

## Architecture

### Four-Layer Structure
1. **Templated Tile Operators** - Low-level tile operations
2. **Templated Kernel/Invoker** - Kernel templates with tile operators
3. **Instantiated Kernel/Invoker** - Concrete kernel instances
4. **Client API** - User-facing API
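
A schematic of how the four layers compose — all names here are invented for illustration and are not actual CK types:

```cpp
// Layering sketch only: shows how a tile operator is wrapped by a kernel
// template, instantiated as a concrete type, and exposed through a client API.
#include <cstdio>

// 1. Templated tile operator
template <typename T, int kM, int kN>
struct TileGemmSketch
{
    static void run() { std::printf("tile gemm %dx%d\n", kM, kN); }
};

// 2. Templated kernel/invoker built from tile operators
template <typename TileOp>
struct KernelSketch
{
    static void invoke() { TileOp::run(); }
};

// 3. Instantiated kernel (a concrete type chosen by codegen or by hand)
using GemmFp32_128x128 = KernelSketch<TileGemmSketch<float, 128, 128>>;

// 4. Client API hiding the template machinery
inline void client_gemm() { GemmFp32_128x128::invoke(); }

int main()
{
    client_gemm();
    return 0;
}
```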

### Key Directories
- `include/ck_tile/core/` - Core utilities (containers, data types, coordinate transforms)
- `include/ck_tile/ops/` - Operator implementations (fmha, gemm, softmax, etc.)
- `include/ck_tile/ops/fmha/pipeline/` - FMHA pipeline implementations (performance-critical)
- `example/ck_tile/` - Working examples with build recipes
- `codegen/` - Python-based kernel code generation
- `profiler/` - Performance profiling tools

### FMHA (Flash Attention) Architecture

#### Directory Structure
```
include/ck_tile/ops/fmha/
├── kernel/ # Kernel entry points (fmha_fwd_kernel.hpp, fmha_fwd_v3_kernel.hpp)
├── pipeline/ # Pipeline implementations (performance-critical)
├── block/ # Block-level components (masking, dropout, position encoding)
└── api/ # High-level API wrappers
```

#### Kernel Template Structure

The kernel (`FmhaFwdKernel`, `FmhaFwdV3Kernel`) has two key template parameters:
- `FmhaPipeline` - Block tile pipeline handling Q*K and P*V computations
- `EpiloguePipeline` - Post-processing and output storage

**Key data types extracted from pipeline:**
- `QDataType`, `KDataType`, `VDataType` - Input types (fp8, fp16, bf16)
- `PDataType` - Attention probability type after softmax
- `SaccDataType` - Accumulator type for the S = Q·Kᵀ GEMM (typically float)
- `ODataType` - Output type

**Configuration flags:**
- `kIsGroupMode` - Variable-length sequences via seqstart pointers
- `kPadSeqLenQ/K`, `kPadHeadDimQ/V` - Padding control
- `kHasLogitsSoftCap` - Gemma-style logits softcap
- `kStoreLSE` - Store log-sum-exp for backward pass
- `QScaleEnum` - FP8 quantization (PERTENSOR, NONE)
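
A minimal sketch of how these pieces hang together — member and type names are illustrative, not the exact CK declarations:

```cpp
// Sketch only: a forward kernel parameterized by a block pipeline and an
// epilogue, re-exporting data types and compile-time flags from the pipeline.
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernelSketch
{
    using FmhaPipeline     = FmhaPipeline_;
    using EpiloguePipeline = EpiloguePipeline_;

    // data types pulled from the pipeline's problem description
    using QDataType    = typename FmhaPipeline::QDataType;
    using KDataType    = typename FmhaPipeline::KDataType;
    using VDataType    = typename FmhaPipeline::VDataType;
    using SaccDataType = typename FmhaPipeline::SaccDataType; // S = Q*K^T accumulator
    using ODataType    = typename EpiloguePipeline::ODataType;

    // compile-time configuration flags forwarded from the pipeline traits
    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
};
```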

#### Pipeline Implementations

| Pipeline | Name | Description |
|----------|------|-------------|
| `BlockFmhaPipelineQRKSVS` | "qr" | Q held in registers; K and V staged through LDS. For medium sequences. |
| `BlockFmhaPipelineQRKSVSAsync` | "qr_async" | Q in registers, async K/V loading. For longer sequences. |
| `BlockFmhaFwdV3Pipeline` | "v3" | Next-gen with warp group coordination and instruction scheduling. |
| `BlockFmhaPipelineSplitKV` | - | Multi-pass with reduction for very long sequences. |
| `BlockFmhaPipelinePagedKV` | - | KV-cache paging for inference. |

#### Attention Computation Flow (Online Softmax)

```
Phase 1: GEMM0 (Q × K^T → S)
├── Load Q tile (M0 × K0) into registers
├── Loop over K tiles (N0 × K0):
│ ├── Async load K tile to LDS
│ ├── Sync barrier
│ └── Block GEMM with MFMA → S accumulator
└── Apply scale: S *= 1/sqrt(hdim)

Phase 2: Online Softmax
├── Row-wise max: m_j = max(S_j)
├── Optional: logits softcap (tanh transform)
├── Exponential: P = exp(S - m_j)
├── Row-wise sum: l_j = sum(P_j)
└── Rescale accumulator: O *= exp(m_old - m_new)

Phase 3: GEMM1 (P × V → O)
├── Convert P to compute type
├── Load V tiles (K1 × N1)
├── Block GEMM with MFMA → O accumulator
└── Finalize: O /= l_j

Phase 4: Epilogue
├── Convert O to output type
├── Optional: store LSE = m/log(2) + log(l)
└── Write O tile to DRAM
```
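
For reference, a minimal host-side C++ model of the online-softmax update above (single query row, no masking or parallelism — purely illustrative, not the kernel code):

```cpp
// Reference for online softmax with accumulator rescaling: process the score
// row in tiles, keeping a running max (m), running sum (l), and an output
// accumulator (o) that is rescaled whenever the running max changes.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // toy problem: one query row, 8 keys, head dim 1 (one value per key)
    std::vector<float> s = {0.1f, 1.2f, -0.3f, 2.0f, 0.7f, -1.1f, 0.4f, 1.5f}; // scaled Q*K^T
    std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};

    const std::size_t kTile = 4; // process keys in tiles, as a kernel would
    float m = -INFINITY, l = 0.0f, o = 0.0f;

    for (std::size_t start = 0; start < s.size(); start += kTile)
    {
        // Phase 2a: new running max over this tile
        float m_new = m;
        for (std::size_t j = start; j < start + kTile; ++j)
            m_new = std::fmax(m_new, s[j]);

        // Phase 2b: rescale previous partial results to the new max
        const float rescale = std::exp(m - m_new);
        l *= rescale;
        o *= rescale;

        // Phase 2c / Phase 3: accumulate exp(S - m_new) and P*V for this tile
        for (std::size_t j = start; j < start + kTile; ++j)
        {
            const float p = std::exp(s[j] - m_new);
            l += p;
            o += p * v[j];
        }
        m = m_new;
    }

    std::printf("softmax(s) . v = %f\n", o / l); // Phase 4: finalize O /= l
    return 0;
}
```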

#### Memory Management

**LDS Layout:**
- K tiles: N0 × K0, double-buffered for async prefetch
- V tiles: K1 × N1, bank-conflict-aware padding
- Size computed via `Policy::GetSmemSize<Problem>()`

**Async Copy Pattern:**
```cpp
async_load_tile_raw(k_lds_window, k_dram_window); // Non-blocking
move_tile_window(k_dram_window, {kN0, 0});
// ... GEMM computation overlaps with load ...
s_waitcnt_vmcnt<0>(); // Wait before use
```

**Prefetching Strategy:** Load K[i+1] while computing with K[i]
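
A CPU-side model of the double-buffering idea (buffer ping-pong only; the real pipeline issues asynchronous LDS loads and waits on counters, as in the snippet above):

```cpp
// Illustrative ping-pong buffering: prefetch tile i+1 while "computing" with
// tile i. The loads here are synchronous, so only the control flow mirrors
// the async_load_tile_raw / s_waitcnt pattern.
#include <array>
#include <cstdio>
#include <vector>

using Tile = std::vector<float>;

static Tile load_k_tile(int i) { return Tile(64, float(i)); } // stand-in for DRAM->LDS copy
static float compute_with(const Tile& k) { return k.front(); } // stand-in for block GEMM

int main()
{
    const int num_k_tiles = 8;
    std::array<Tile, 2> lds_k; // two LDS buffers

    lds_k[0] = load_k_tile(0); // prologue: prefetch the first K tile
    float acc = 0.0f;

    for (int i = 0; i < num_k_tiles; ++i)
    {
        const int cur = i & 1, nxt = cur ^ 1;
        if (i + 1 < num_k_tiles)
            lds_k[nxt] = load_k_tile(i + 1); // issue prefetch of K[i+1] ...
        acc += compute_with(lds_k[cur]);     // ... while computing with K[i]
    }
    std::printf("acc = %f\n", acc);
    return 0;
}
```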

#### Block-Level Components

**Masking (`block_masking.hpp`):**
- `MASK_FROM_TOP_LEFT` - Causal (lower triangular)
- `MASK_FROM_BOTTOM_RIGHT` - Causal mask anchored at the bottom-right corner (used when seqlen_q != seqlen_k)
- Local attention via `window_size_left/right`
- `GenericAttentionMask::GetTileRangeAlongX()` - Skip fully masked tiles
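
As an illustration of the tile-skipping idea (not the CK `GenericAttentionMask` code), the reachable key-tile range under a top-left causal mask can be computed like this:

```cpp
// For a query tile covering rows [m_start, m_start + kM0) under a top-left
// causal mask, only key columns below m_start + kM0 are unmasked, so key
// tiles at or beyond ceil((m_start + kM0) / kN0) can be skipped entirely.
#include <algorithm>
#include <cstdio>

struct TileRange { int begin, end; }; // key-tile indices, half-open [begin, end)

TileRange causal_key_tile_range(int m_start, int kM0, int kN0, int seqlen_k)
{
    const int last_visible_col = std::min(m_start + kM0, seqlen_k); // exclusive
    const int end_tile = (last_visible_col + kN0 - 1) / kN0;        // ceiling division
    return {0, std::max(end_tile, 0)};
}

int main()
{
    // query tile index 3 of size 128, key tiles of size 64, seqlen_k = 4096
    const TileRange r = causal_key_tile_range(3 * 128, 128, 64, 4096);
    std::printf("key tiles [%d, %d)\n", r.begin, r.end); // -> [0, 8)
    return 0;
}
```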

**Quantization (`block_attention_quant_scale_enum.hpp`):**
- `NONE` - Standard float operations
- `PERTENSOR` - Single scale per Q/K/V tensor (FP8)
- Flow: `Q_fp8 * scale → float → compute → saturate → O_fp8`
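
A float-only toy model of the per-tensor flow (real kernels use hardware FP8 conversions; the saturation bound below assumes the e4m3 format):

```cpp
// Per-tensor FP8-style quantization model: dequantize inputs with one scale
// per tensor, compute in float, then saturate and requantize the output.
// FP8_MAX = 448 is the e4m3 maximum finite value (an assumption of this sketch).
#include <algorithm>
#include <cstdio>

constexpr float FP8_MAX = 448.0f;

float dequant(float q_fp8, float descale) { return q_fp8 * descale; }

float quant_saturate(float x, float scale)
{
    return std::clamp(x * scale, -FP8_MAX, FP8_MAX); // saturate into FP8 range
}

int main()
{
    const float q_descale = 0.02f, k_descale = 0.03f, o_scale = 1.0f / 0.05f;

    // one "dot product" element: Q_fp8 * scale -> float -> compute
    const float q = dequant(120.0f, q_descale);
    const float k = dequant( 96.0f, k_descale);
    const float s = q * k; // float compute

    const float o_fp8 = quant_saturate(s, o_scale); // saturate -> O_fp8 (before rounding)
    std::printf("s = %f, o_fp8 (pre-rounding) = %f\n", s, o_fp8);
    return 0;
}
```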

#### Policy/Trait Configuration

**TileFmhaTraits** - Core configuration:
```cpp
template <bool kPadSeqLenQ,
          bool kPadSeqLenK,
          bool kPadHeadDimQ,
          bool kPadHeadDimV,
          bool kHasLogitsSoftCap,
          BlockAttentionBiasEnum BiasEnum,
          bool kStoreLSE,
          bool kHasDropout,
          BlockAttentionQuantScaleEnum QScaleEnum>
struct TileFmhaTraits;
```

**Default Policy** provides:
- Alignment hints for DRAM loads
- GEMM configurations (MFMA instruction selection)
- LDS store/load descriptors
- Register tile distributions
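
A rough sketch of what such a policy can look like — the member names and padding scheme are illustrative, not the actual CK policy interface:

```cpp
// Illustrative policy shape: a Problem type carries compile-time block sizes
// and data types, and the policy derives LDS requirements and DRAM access
// alignments from it.
#include <cstddef>

template <typename Problem>
struct DefaultPolicySketch
{
    // LDS bytes for K (double-buffered) plus V, with padding against bank conflicts
    static constexpr std::size_t GetSmemSize()
    {
        constexpr std::size_t kPad = 8; // assumed padding elements per row
        constexpr std::size_t k_bytes =
            Problem::kN0 * (Problem::kK0 + kPad) * sizeof(typename Problem::KDataType);
        constexpr std::size_t v_bytes =
            Problem::kK1 * (Problem::kN1 + kPad) * sizeof(typename Problem::VDataType);
        return 2 * k_bytes + v_bytes;
    }

    // vectorized DRAM access width (in elements) for global loads
    static constexpr int kAlignmentQ = 8;
    static constexpr int kAlignmentK = 8;
};
```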

#### Grid/Block Organization

```cpp
dim3 GridSize(batch_size, nhead, ceil(max_seqlen_q / kM0) * ceil(hdim_v / kN1));
dim3 BlockSize(kBlockSize); // Typically 256-512 threads
```
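
A launch shaped like this implies that each block recovers its (batch, head, M-tile, N1-tile) coordinates from the block indices, roughly as in this illustrative (not verbatim) decoding:

```cpp
// Block-index decoding for the grid layout above:
//   gridDim = (batch, nhead, num_m_tiles * num_n1_tiles)
// Each block maps blockIdx.z back to an (M-tile, N1-tile) pair.
#include <cstdio>

struct BlockCoord { int batch, head, i_tile_m, i_tile_n1; };

BlockCoord decode_block(int bx, int by, int bz, int num_n1_tiles)
{
    return BlockCoord{bx, by, bz / num_n1_tiles, bz % num_n1_tiles};
}

int main()
{
    // e.g. hdim_v = 128, kN1 = 32 -> 4 N1 tiles per row of M tiles
    const BlockCoord c =
        decode_block(/*blockIdx.x*/ 1, /*blockIdx.y*/ 5, /*blockIdx.z*/ 9, /*num_n1_tiles*/ 4);
    std::printf("batch=%d head=%d m_tile=%d n1_tile=%d\n",
                c.batch, c.head, c.i_tile_m, c.i_tile_n1);
    return 0;
}
```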

#### V3 Pipeline Optimizations

- **Warp Group Specialization** - 2 warp groups (4 waves each) with different roles
- **Phase Scheduling** - Explicit barriers for MFMA/VALU/TRANS timing
- **Packed FP32** - `v_pk_mul_f32` for two operations per instruction
- **Fast Exp2** - Bit manipulation approximation
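
The general bit-manipulation idea behind a fast exp2 (shown here as a generic approximation, not the V3 kernel's exact code) is to construct the float's exponent field directly from x:

```cpp
// Generic fast 2^x approximation: scale x by 2^23 and add the IEEE-754 float
// exponent bias so the integer part lands in the exponent field; the
// fractional part of x spills into the mantissa, yielding a piecewise-linear
// approximation of 2^x. Illustrative only.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float fast_exp2(float x)
{
    const float scaled = x * 8388608.0f + 127.0f * 8388608.0f; // 2^23 = 8388608
    const std::int32_t bits = static_cast<std::int32_t>(scaled);
    float result;
    std::memcpy(&result, &bits, sizeof(result)); // reinterpret the bits as a float
    return result;
}

int main()
{
    for (float x : {-2.0f, -0.5f, 0.0f, 1.0f, 3.3f})
        std::printf("x=%5.2f  fast=%10.5f  exact=%10.5f\n", x, fast_exp2(x), std::exp2(x));
    return 0;
}
```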

## Key Concepts

- **Tile** - Fixed-size data chunk processed by a thread block
- **Block Tile** - Tile owned by entire thread block
- **Wave Tile** - Tile owned by a wavefront (64 threads on AMD)
- **LDS** - Local Data Share (AMD's shared memory)
- **MFMA** - Matrix Fused Multiply-Add (AMD's matrix core instruction)
- **XDL** - XDLOPS, CK's name for the MFMA-based matrix instructions (seen in kernel names such as `gemm_xdl`)

See `TERMINOLOGY.md` and `ACRONYMS.md` for complete references.

## Common Variable Naming

| Symbol | Meaning |
|--------|---------|
| M, N, K | GEMM dimensions: A[M,K] × B[K,N] = C[M,N] |
| Q, K, V | Query, Key, Value (attention) |
| S | Sequence length |
| D | Head dimension |
| B | Batch size |
| H | Number of attention heads |

## Running FMHA Examples

```bash
# Basic FMHA forward
./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128

# With FP8
./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -prec=fp8

# Group mode (variable length)
./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -d=128

# With causal mask
./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -mask=t
```

Use `-?` flag to see all options.

## Codegen System

Kernels are instantiated into separate files via Python scripts in `codegen/` to enable parallel compilation. Example: `example/ck_tile/01_fmha/codegen/generate.py`.
Empty file added GEMINI.md
Empty file.
46 changes: 27 additions & 19 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -210,14 +210,21 @@
((0 < args.window_size_left) or (0 < args.window_size_right));
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
((false) or
((traits.data_type.compare("fp8bf16") == 0) and
(traits.qscale_type == quant_scale_enum::pertensor))) and
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if(traits.data_type.compare("fp8bf16") == 0)
printf("[POYENC] can_dispatch_v3(=%d) -> ", can_dispatch_v3);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
if(traits.data_type.compare("fp8bf16") == 0)
printf("dispatch to fmha_fwd_v3()\\n");
return fmha_fwd_v3(traits, args, config);
}} else {{
if(traits.data_type.compare("fp8bf16") == 0)
printf("dispatch to fmha_fwd_v2()\\n");
return fmha_fwd_v2(traits, args, config);
}}
}}
@@ -934,27 +941,14 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
} # fmt: skip
elif dtype in cls._DT_FP16_BF16:
return {
( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16:
return {
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
elif dtype in cls._DT_FP8FP32:
return {
@@ -1048,6 +1042,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
elif dtype in cls._DT_FP8BF16:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip
return result

@classmethod
@@ -1059,6 +1057,7 @@ def get_pipelines(
)
if dtype in cls._DT_FP16_BF16:
qscale = "no"
"""
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -1084,7 +1083,16 @@
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip

"""
elif dtype in cls._DT_FP8BF16:
# no need lse/dropout kernels
# qr_async_trload_v3 only supports (generic) causal mask
for logits, qscale, mask in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "causal"],
):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
return pipelines


@@ -1365,8 +1373,8 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool:
FMHA_FWD_API_FOOTER_TEMPLATE.format(
F_is_v3_enabled=BOOL_MAP[
# NOTE: enable v3 pipelines when ready
# 0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
False
0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
# False
]
),
]
6 changes: 6 additions & 0 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -738,6 +738,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqstart_q_ptr,
@@ -771,6 +774,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
3 changes: 3 additions & 0 deletions example/ck_tile/01_fmha/mask.hpp
@@ -148,6 +148,9 @@ struct mask_info
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
tmp.left = -1;
tmp.right = -1;
tmp.sink = 0;
}
else if(str == "1" || str == "t")
{