diff --git a/src/vae.hpp b/src/vae.hpp index 01081343b..c627616c2 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock { v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] } - h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); + h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled); if (use_linear) { h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] diff --git a/src/wan.hpp b/src/wan.hpp index 7b1059785..90de3bdd1 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -572,8 +572,8 @@ namespace WAN { auto v = qkv_vec[2]; v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] - v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c] + v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled); // [t, h * w, c] x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]