Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ def setup_jax_extension(

# Header files
include_dirs = get_cuda_include_dirs()
cudnn_frontend_include_dir = None
for base_path in (Path(common_header_files), *Path(common_header_files).parents):
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is None:
for base_path in Path(__file__).resolve().parents:
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is not None:
include_dirs.append(cudnn_frontend_include_dir)
include_dirs.extend(
[
common_header_files,
Expand Down
1 change: 1 addition & 0 deletions qa/L0_jax_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_score_mod.xml $TE_PATH/tests/jax/test_fused_attn_score_mod.py || test_fail "tests/jax/test_fused_attn_score_mod.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
Expand Down
147 changes: 145 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs_for_attn,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from test_fused_attn import (
FusedAttnRunner,
BiasShape,
SeqDescFormat,
customcall_fused_dpa,
)
from test_fused_attn_score_mod import (
_ScoreModSoftcap,
_has_cudnn_frontend_python,
_reference_attention,
_require_cudnn_frontend_score_mod,
)
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax import autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
Expand Down Expand Up @@ -272,6 +285,136 @@ def test_cross_attn(
runner.test_backward()


DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
}
Comment thread
vcherepanov-nv marked this conversation as resolved.


@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required")
class TestDistributedScoreModSelfAttn:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_softcap_score_mod_with_aux_params_backward(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
):
_require_cudnn_frontend_score_mod()
batch, seqlen, num_heads, head_dim = data_shape
Comment thread
vcherepanov-nv marked this conversation as resolved.
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource

if dp_axis is not None:
dp_size = mesh_shape[mesh_axes.index(dp_axis)]
if batch % dp_size != 0:
pytest.skip(f"{batch=} must be divisible by {dp_size=}")
if tp_axis is not None:
tp_size = mesh_shape[mesh_axes.index(tp_axis)]
if num_heads % tp_size != 0:
pytest.skip(f"{num_heads=} must be divisible by {tp_size=}")

runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_heads,
num_heads,
head_dim,
head_dim,
AttnBiasType.NO_BIAS,
AttnMaskType.NO_MASK,
AttnSoftmaxType.VANILLA_SOFTMAX,
0.0,
dtype,
True,
QKVLayout.BSHD_BSHD_BSHD,
None,
None,
SeqDescFormat.Mask,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
)
runner._setup_inputs()

qkv_sharding = NamedSharding(runner.mesh, PartitionSpec(dp_axis, None, tp_axis, None))
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)

scaling_factor = runner.scaling_factor
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

def score_mod_loss(q, k, v, dout):
out = customcall_fused_dpa(
q,
k,
v,
None,
None,
None,
None,
attn_bias_type=AttnBiasType.NO_BIAS,
attn_mask_type=AttnMaskType.NO_MASK,
qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
scaling_factor=scaling_factor,
dropout_probability=0.0,
is_training=True,
score_mod=softcap_score_mod.forward,
score_mod_bprop=softcap_score_mod.backward,
score_mod_tensors={"softcap": softcap},
score_mod_bprop_tensors={"softcap": softcap},
)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

def ref_loss(q, k, v, dout):
out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

jitted_score_mod = jax.jit(
jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
in_shardings=(
qkv_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
),
out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
jax.device_put(query, qkv_sharding),
jax.device_put(key_tensor, qkv_sharding),
jax.device_put(value, qkv_sharding),
jax.device_put(doutput, qkv_sharding),
)
with runner.mesh, autocast(mesh_resource=mesh_resource):
(score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args)
(ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput)

assert score_mod_out.sharding == qkv_sharding
for grad in score_mod_grads:
assert grad.sharding == qkv_sharding

assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2)
assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2)
for grad, ref_grad in zip(score_mod_grads, ref_grads):
assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2)


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
Expand Down
Loading
Loading