diff --git a/.gitignore b/.gitignore index 740d5464fb9..cee6233a887 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ experimental/grouped_convolution_tile_instances/instances/* !experimental/grouped_convolution_tile_instances/instances/*.in !experimental/grouped_convolution_tile_instances/instances/*.inc experimental/grouped_convolution_tile_instances/*.inc + +# AI agent config +.claude diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000000..672ce56b8dc --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,243 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Composable Kernel (CK) is AMD's high-performance GPU kernel library for machine learning workloads. It uses HIP C++ with a tile-based programming model and tensor coordinate transformation techniques. + +**Two implementations exist:** +- **CK Tile** (`include/ck_tile/`) - Modern tile-programming API, preferred for new development +- **Legacy CK** (`include/ck/`) - Older implementation, still supported + +## Build Commands + +```bash +# Development build (from build directory) +mkdir build && cd build +../script/cmake-ck-dev.sh .. "gfx908;gfx90a;gfx942" +make -j32 # Use ~2GB RAM per thread + +# Build specific targets +make tile_example_fmha_fwd # FMHA forward example +make tile_example_fmha_bwd # FMHA backward example +make ckProfiler # Performance profiler + +# Tests +make smoke # Quick tests (<30s each) +make regression # Long tests (>=30s each) +make check # All tests + +# Single test +ctest -R "fmha" -V +``` + +**CMake options:** +- `GPU_TARGETS="gfx908;gfx90a;gfx942"` - Target GPU architectures +- `DTYPES="fp32;fp16;fp8;bf16;int8"` - Data types to build +- `BUILD_DEV=ON` - Development mode + +## Code Formatting + +CK uses clang-format. Install pre-commit hooks: +```bash +sudo script/install_precommit.sh +``` + +Disable temporarily with `git commit --no-verify`. 
+ +## Architecture + +### Four-Layer Structure +1. **Templated Tile Operators** - Low-level tile operations +2. **Templated Kernel/Invoker** - Kernel templates with tile operators +3. **Instantiated Kernel/Invoker** - Concrete kernel instances +4. **Client API** - User-facing API + +### Key Directories +- `include/ck_tile/core/` - Core utilities (containers, data types, coordinate transforms) +- `include/ck_tile/ops/` - Operator implementations (fmha, gemm, softmax, etc.) +- `include/ck_tile/ops/fmha/pipeline/` - FMHA pipeline implementations (performance-critical) +- `example/ck_tile/` - Working examples with build recipes +- `codegen/` - Python-based kernel code generation +- `profiler/` - Performance profiling tools + +### FMHA (Flash Attention) Architecture + +#### Directory Structure +``` +include/ck_tile/ops/fmha/ +├── kernel/ # Kernel entry points (fmha_fwd_kernel.hpp, fmha_fwd_v3_kernel.hpp) +├── pipeline/ # Pipeline implementations (performance-critical) +├── block/ # Block-level components (masking, dropout, position encoding) +└── api/ # High-level API wrappers +``` + +#### Kernel Template Structure + +The kernel (`FmhaFwdKernel`, `FmhaFwdV3Kernel`) has two key template parameters: +- `FmhaPipeline` - Block tile pipeline handling Q*K and P*V computations +- `EpiloguePipeline` - Post-processing and output storage + +**Key data types extracted from pipeline:** +- `QDataType`, `KDataType`, `VDataType` - Input types (fp8, fp16, bf16) +- `PDataType` - Attention probability type after softmax +- `SaccDataType` - Accumulator type for the S = Q×K^T GEMM (typically float) +- `ODataType` - Output type + +**Configuration flags:** +- `kIsGroupMode` - Variable-length sequences via seqstart pointers +- `kPadSeqLenQ/K`, `kPadHeadDimQ/V` - Padding control +- `kHasLogitsSoftCap` - Gemma-style logits softcap +- `kStoreLSE` - Store log-sum-exp for backward pass +- `QScaleEnum` - FP8 quantization mode (`BlockAttentionQuantScaleEnum`: PERTENSOR, NO_SCALE) + +#### Pipeline Implementations + +| Pipeline | Name | Description | 
+|----------|------|-------------| +| `BlockFmhaPipelineQRKSVS` | "qr" | LDS-based, all QKV in LDS. For medium sequences. | +| `BlockFmhaPipelineQRKSVSAsync` | "qr_async" | Q in registers, async K/V loading. For longer sequences. | +| `BlockFmhaFwdV3Pipeline` | "v3" | Next-gen with warp group coordination and instruction scheduling. | +| `BlockFmhaPipelineSplitKV` | - | Multi-pass with reduction for very long sequences. | +| `BlockFmhaPipelinePagedKV` | - | KV-cache paging for inference. | + +#### Attention Computation Flow (Online Softmax) + +``` +Phase 1: GEMM0 (Q × K^T → S) +├── Load Q tile (M0 × K0) into registers +├── Loop over K tiles (N0 × K0): +│ ├── Async load K tile to LDS +│ ├── Sync barrier +│ └── Block GEMM with MFMA → S accumulator +└── Apply scale: S *= 1/sqrt(hdim) + +Phase 2: Online Softmax +├── Row-wise max: m_j = max(S_j) +├── Optional: logits softcap (tanh transform) +├── Exponential: P = exp(S - m_j) +├── Row-wise sum: l_j = sum(P_j) +└── Rescale accumulator: O *= exp(m_old - m_new) + +Phase 3: GEMM1 (P × V → O) +├── Convert P to compute type +├── Load V tiles (K1 × N1) +├── Block GEMM with MFMA → O accumulator +└── Finalize: O /= l_j + +Phase 4: Epilogue +├── Convert O to output type +├── Optional: store LSE = m/log(2) + log(l) +└── Write O tile to DRAM +``` + +#### Memory Management + +**LDS Layout:** +- K tiles: N0 × K0, double-buffered for async prefetch +- V tiles: K1 × N1, bank-conflict-aware padding +- Size computed via `Policy::GetSmemSize()` + +**Async Copy Pattern:** +```cpp +async_load_tile_raw(k_lds_window, k_dram_window); // Non-blocking +move_tile_window(k_dram_window, {kN0, 0}); +// ... GEMM computation overlaps with load ... 
+s_waitcnt_vmcnt<0>(); // Wait before use +``` + +**Prefetching Strategy:** Load K[i+1] while computing with K[i] + +#### Block-Level Components + +**Masking (`block_masking.hpp`):** +- `MASK_FROM_TOP_LEFT` - Causal (lower triangular) +- `MASK_FROM_BOTTOM_RIGHT` - Future tokens +- Local attention via `window_size_left/right` +- `GenericAttentionMask::GetTileRangeAlongX()` - Skip fully masked tiles + +**Quantization (`block_attention_quant_scale_enum.hpp`):** +- `NO_SCALE` - Standard float operations +- `PERTENSOR` - Single scale per Q/K/V tensor (FP8) +- Flow: `Q_fp8 * scale → float → compute → saturate → O_fp8` + +#### Policy/Trait Configuration + +**TileFmhaTraits** - Core configuration: +```cpp +template < + bool kPadSeqLenQ, bool kPadSeqLenK, + bool kPadHeadDimQ, bool kPadHeadDimV, + bool kHasLogitsSoftCap, + BlockAttentionBiasEnum BiasEnum, + bool kStoreLSE, + bool kHasDropout, + BlockAttentionQuantScaleEnum QScaleEnum +> +struct TileFmhaTraits; +``` + +**Default Policy** provides: +- Alignment hints for DRAM loads +- GEMM configurations (MFMA instruction selection) +- LDS store/load descriptors +- Register tile distributions + +#### Grid/Block Organization + +```cpp +dim3 GridSize(batch_size, nhead, ceil(max_seqlen_q / kM0) * ceil(hdim_v / kN1)); +dim3 BlockSize(kBlockSize); // Typically 256-512 threads +``` + +#### V3 Pipeline Optimizations + +- **Warp Group Specialization** - 2 warp groups (4 waves each) with different roles +- **Phase Scheduling** - Explicit barriers for MFMA/VALU/TRANS timing +- **Packed FP32** - `v_pk_mul_f32` for two operations per instruction +- **Fast Exp2** - Bit manipulation approximation + +## Key Concepts + +- **Tile** - Fixed-size data chunk processed by a thread block +- **Block Tile** - Tile owned by entire thread block +- **Wave Tile** - Tile owned by a wavefront (64 threads on AMD) +- **LDS** - Local Data Share (AMD's shared memory) +- **MFMA** - Matrix Fused Multiply-Add (AMD's matrix core instruction) +- **XDL** - Crosslane Data Layout 
instructions + +See `TERMINOLOGY.md` and `ACRONYMS.md` for complete references. + +## Common Variable Naming + +| Symbol | Meaning | +|--------|---------| +| M, N, K | GEMM dimensions: A[M,K] × B[K,N] = C[M,N] | +| Q, K, V | Query, Key, Value (attention) | +| S | Sequence length | +| D | Head dimension | +| B | Batch size | +| H | Number of attention heads | + +## Running FMHA Examples + +```bash +# Basic FMHA forward +./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128 + +# With FP8 +./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -prec=fp8 + +# Group mode (variable length) +./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -d=128 + +# With causal mask +./bin/tile_example_fmha_fwd -b=1 -h=8 -s=4096 -d=128 -mask=t +``` + +Use `-?` flag to see all options. + +## Codegen System + +Kernels are instantiated into separate files via Python scripts in `codegen/` to enable parallel compilation. Example: `example/ck_tile/01_fmha/codegen/generate.py`. diff --git a/GEMINI.md b/GEMINI.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b3..3885f23a66c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -210,14 +210,21 @@ ((0 < args.window_size_left) or (0 < args.window_size_right)); const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and - (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + ((false) or + ((traits.data_type.compare("fp8bf16") == 0) and + (traits.qscale_type == quant_scale_enum::pertensor))) and traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and - (not traits.has_lse) and (not traits.has_dropout) and - (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and + (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) 
and (args.hdim_q == 128) and (args.hdim_v == 128); + if(traits.data_type.compare("fp8bf16") == 0) + printf("[POYENC] can_dispatch_v3(=%d) -> ", can_dispatch_v3); if ({F_is_v3_enabled} and can_dispatch_v3) {{ + if(traits.data_type.compare("fp8bf16") == 0) + printf("dispatch to fmha_fwd_v3()\\n"); return fmha_fwd_v3(traits, args, config); }} else {{ + if(traits.data_type.compare("fp8bf16") == 0) + printf("dispatch to fmha_fwd_v2()\\n"); return fmha_fwd_v2(traits, args, config); }} }} @@ -934,27 +941,14 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: } # fmt: skip elif dtype in cls._DT_FP16_BF16: return { - ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 
4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { - ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip elif dtype in cls._DT_FP8FP32: return { @@ -1048,6 +1042,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: if (128, 128) in result.keys(): result[(128, 128)].append( FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + elif dtype in cls._DT_FP8BF16: + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip return result @classmethod @@ -1059,6 +1057,7 @@ def get_pipelines( ) if dtype in cls._DT_FP16_BF16: qscale = "no" + """ for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -1084,7 +1083,16 @@ def get_pipelines( for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip - + """ + elif dtype in cls._DT_FP8BF16: + # no need lse/dropout kernels + # qr_async_trload_v3 only supports (generic) causal mask + for logits, qscale, mask in itertools.product( + ["t", "f"], + ["no", "pertensor"], + ["no", "causal"], + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", 
F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip return pipelines @@ -1365,8 +1373,8 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: FMHA_FWD_API_FOOTER_TEMPLATE.format( F_is_v3_enabled=BOOL_MAP[ # NOTE: enable v3 pipelines when ready - # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - False + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + # False ] ), ] diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fdd720fd75b..d5b06d6e5a3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -738,6 +738,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqstart_q_ptr, @@ -771,6 +774,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqlen_q, diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index f85b811116b..010224cb0ea 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -148,6 +148,9 @@ struct mask_info else if(str == "0") { tmp.type = mask_enum::no_mask; + tmp.left = -1; + tmp.right = -1; + tmp.sink = 0; } else if(str == "1" || str == "t") { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 6fe1de634d9..c2e0fe0d4cc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = 
ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel float logits_soft_cap_rcp; }; + struct FmhaFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, @@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, @@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = 
q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -640,32 +675,82 @@ struct FmhaFwdV3Kernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + AttentionVariant variant; const auto variant_params = [&] { if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / 
kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lse_dram_window, - mask, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return make_composes( + ck_tile::saturates{}, + ck_tile::scales>{scale_o}); + else + return ck_tile::scales>{scale_o}; + }(); + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales>{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } }(); // O DRAM and O DRAM window diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..09ca59be420 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -300,7 +300,6 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout && - (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), "enable unsupported 
features"); @@ -437,7 +436,7 @@ struct BlockFmhaFwdV3Pipeline kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + // static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); @@ -854,12 +853,21 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } - else + else if constexpr(std::is_same_v) { auto casted = detail::cvt_pk_bf16_f32(x, y); sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } + else if constexpr(std::is_same_v) + { + sp(sp_reg_idx).p.thread_buf_[idx] = type_convert(x); + sp(sp_reg_idx).p.thread_buf_[idx + 1] = type_convert(y); + } + else + { + static_assert(false, "unsupported data type for P"); + } }); /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index ce097b6741b..2c626dfee12 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -140,9 +140,10 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // Swap NumWarps and LaneGroups to store V in non-swizzled layout in LDS constexpr index_t N0 = NumIssues; - constexpr index_t N1 = LaneGroups; - constexpr index_t N2 = NumWarps; + constexpr index_t N1 = NumWarps; // was LaneGroups + constexpr index_t N2 = LaneGroups; // was NumWarps constexpr index_t K0 = LanesPerK; constexpr index_t K1 = KVector; @@ -150,7 +151,7 @@ struct 
BlockFmhaV3PipelineDefaultPolicy tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, + tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } @@ -239,10 +240,19 @@ struct BlockFmhaV3PipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr auto warp_gemm = [] { + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + /// NOTICE: in order to use load_tile() for K tile with correct row stride, + /// we cannot use WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution here + /// because SwizzleB encoding has half-stride issue for K tile loading + return WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here @@ -307,8 +317,8 @@ struct BlockFmhaV3PipelineDefaultPolicy return BlockGemmARegBRegCRegV2{}; } - static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords - static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + static constexpr ck_tile::index_t kKLdsPadInBytes = 0; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 0; // 16 dwords template CK_TILE_DEVICE static constexpr auto @@ -489,13 +499,13 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( make_tuple(number{}, // n0 - number{}, // n1 number{}, // n2 + number{}, // n1 number{}, // k0 number{}), // k1 make_tuple(number{}, - number{}, number{}, + number{}, number{}, number<1>{}), number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, @@ -510,7 +520,7 @@ struct BlockFmhaV3PipelineDefaultPolicy 
make_pass_through_transform(number{}), make_merge_transform(make_tuple( number{}, number{}, number{}))), - make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); return v_lds_block_desc_issues_warps_lanes; @@ -558,9 +568,9 @@ struct BlockFmhaV3PipelineDefaultPolicy v_lds_block_desc_0, make_tuple( make_merge_transform( - make_tuple(number{}, number{}, number{})), + make_tuple(number{}, number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); return v_lds_block_desc; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 00512424752..ff724f41d68 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -288,6 +288,13 @@ using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; +template +using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed = + WarpGemmImpl, + 2, + AttrNumAccess>>; + using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index a7d71d4fa3c..67b6586ea81 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -100,6 +100,8 @@ template<> struct Dispatcher { u // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; }; +template<> struct 
Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };