diff --git a/autoparallel/api.py b/autoparallel/api.py index e5da5d67..70e65dbe 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -60,6 +60,8 @@ def _get_decomp_table(): decomp_table.pop(torch.ops.aten.native_layer_norm.default) decomp_table.pop(torch.ops.aten.embedding_dense_backward.default) decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default) + decomp_table.pop(torch.ops.aten._softmax_backward_data.default) + decomp_table.pop(torch.ops.aten._softmax.default) # decompose addmm to allow for TP on mm decomp_table.pop(torch.ops.aten.addmm.default)