Skip to content

Commit d44009f

Browse files
authored
Support tensor.T for transpose (#1110)
1 parent fc0dcff commit d44009f

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def propagate_unary(self, op: ast.unaryop, origin: Origin) -> TypeInfo:
444444

445445
def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
446446
assert origin.key == attr
447-
if attr in {"dtype", "device", "ndim", "shape"}:
447+
if attr in {"dtype", "device", "ndim", "shape", "T"}:
448448
return TypeInfo.from_example(getattr(self.fake_value, attr), origin)
449449
return TensorAttributeType(origin, self)
450450

test/test_views.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
9696
_code, result = code_and_output(fn, args)
9797
torch.testing.assert_close(result, args[0] + args[1].transpose(0, 1))
9898

99+
def test_transpose_T_unsqueeze(self):
100+
@helion.kernel(autotune_effort="none")
101+
def fn(x: torch.Tensor) -> torch.Tensor:
102+
out = torch.empty_like(x)
103+
for tile_n, tile_m in hl.tile(x.size()):
104+
tile3d = x[tile_n, tile_m].T.unsqueeze(0)
105+
out[tile_n, tile_m] = tile3d.squeeze(0).T
106+
return out
107+
108+
args = (torch.randn([512, 384], device=DEVICE),)
109+
_, result = code_and_output(fn, args)
110+
torch.testing.assert_close(result, args[0])
111+
99112
@unittest.skipUnless(
100113
supports_tensor_descriptor(), "Tensor descriptor support is required"
101114
)

0 commit comments

Comments
 (0)