Conversation
e2e6e9f to 69c9679
johannaSommer
left a comment
First PR and already almost flawless, big 👏🏻👏🏻👏🏻 coming your way soon!
src/pruna/algorithms/sage_attn.py
Outdated
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = []
compatible_after: Iterable[str] = ["torch_compile"]
compatible_after would also be tags.CACHERS, and compatible_before probably also tags.QUANTIZERS
then add this compatibility also in other algorithms
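For illustration, the suggested declaration might look roughly like this. This is a sketch only: the placeholder `tags` namespace stands in for pruna's real tag constants, whose values and import path are not shown in the diff.

```python
from collections.abc import Iterable
from types import SimpleNamespace

# Placeholder for pruna's tag constants referenced in the review; the real
# values live inside pruna and are only assumed here.
tags = SimpleNamespace(QUANTIZERS=("quantizer",), CACHERS=("cacher",))

runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = [*tags.QUANTIZERS]               # quantizers may run before sage_attn
compatible_after: Iterable[str] = ["torch_compile", *tags.CACHERS]  # cachers and torch.compile may run after
```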
src/pruna/algorithms/sage_attn.py
Outdated
return False

return any(
    hasattr(component, "set_attention_backend") and component.dtype in [torch.bfloat16, torch.float16]
I recall this dtype check for the components from flash attention (attention needs to be computed in this precision for FA3 to work); did we double-check that this is also the case here?
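For reference, the check under discussion roughly amounts to the following; this is a sketch only, and the helper name is made up for illustration.

```python
import torch


def _supports_sage_attention(component) -> bool:
    # A component qualifies if it exposes diffusers' set_attention_backend hook
    # and already runs in half precision, mirroring the FA3-style constraint the
    # reviewer mentions (fp16/bf16 q, k, v tensors).
    return hasattr(component, "set_attention_backend") and getattr(
        component, "dtype", None
    ) in (torch.float16, torch.bfloat16)
```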
src/pruna/algorithms/sage_attn.py
Outdated
# We simply apply the sage attention backend from diffusers
# Furthermore, we use the sage attention kernel from the hub as the default sageattn function
# is broken (at least at the moment)
for component in model.components.values():
as discussed, let's add target modules also here :)
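A hedged sketch of what restricting the backend switch to the requested target modules could look like; the function name, the matching logic, and the backend string are assumptions for illustration, not pruna's existing helper (see the later comment about reusing existing functionality instead of duplicating it).

```python
def apply_sage_backend(model, target_modules: list[str]) -> None:
    # Only patch the components the user asked for; fall back to all
    # components when no target modules were configured.
    for name, component in model.components.items():
        if target_modules and name not in target_modules:
            continue
        if hasattr(component, "set_attention_backend"):
            # "sage" is assumed to select the SageAttention backend via
            # diffusers' hook; the PR itself picks the kernel-hub variant
            # because the default sageattn function produced noise.
            component.set_attention_backend("sage")
```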
src/pruna/algorithms/sage_attn.py
Outdated
    configuration system.
    """
    return [
        Boolean(
this is actually not needed and we can remove it, as the user can specify this exactly through the target modules anyway (there is a smash config interface for this)
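For context, the smash-config route the reviewer refers to would look roughly like this. Everything here is illustrative: only the `target_modules` key appears in the diff, while the algorithm group key and the module name are assumptions.

```python
from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["kernel"] = "sage_attn"              # assumed group key for this algorithm
smash_config["target_modules"] = ["transformer"]  # assumed value format; key taken from the diff
```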
src/pruna/algorithms/sage_attn.py
Outdated
    The wrapped model.
    """
    target_modules = smash_config["target_modules"]
    exclude_first_and_last_transformer_blocks = smash_config["exclude_first_and_last_transformer_blocks"]
for the target modules, let's please use the functionality we already have, otherwise we have a lot of duplicate code here
This PR has been inactive for 10 days and is now marked as stale.
johannaSommer
left a comment
Just two more comments regarding target modules, then we are gtg! :)
…antizers as compatible after and before, add sage_attn in corresponding cachers and quantizers algorithms as compatible, add dtype check as sage_attn only works for float/bfloat16 (double checked), add target modules (but not fully finished yet)
…ast attention block per attention component. Remove dtype guard as the dtypes of q, k, and v per attn module are implicitly checked by the sage attention kernel.
…s default target module, remove warning print
3ea1b21 to 7b196e6
This PR has been inactive for 10 days and is now marked as stale.
Description
Integration of the Sage Attention algorithm into the Pruna framework. The current version applies the attention backend from Diffusers, choosing the Sage Attention kernel from the Kernel Hub. This is because the original sageattn function appears to be broken (its outputs were pure noise). Additionally, tests for the Sage Attention algorithm were implemented.
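A minimal sketch of what the integration effectively does to a diffusers pipeline, assuming a model whose transformer exposes set_attention_backend; the model id and backend string are illustrative, and the PR itself selects the Kernel-Hub SageAttention variant rather than the default sageattn function.

```python
import torch
from diffusers import DiffusionPipeline

# Load a pipeline in half precision, as SageAttention expects fp16/bf16 inputs.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the SageAttention backend.
pipe.transformer.set_attention_backend("sage")

image = pipe("a photo of an astronaut riding a horse").images[0]
```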
Related Issue
No issues were fixed.
Type of Change
How Has This Been Tested?
Reused the tests for flashattn3, adapted to Sage Attention.
Checklist
Additional Notes
/