Skip to content

Commit 87d1097

Browse files
Merge pull request #61 from PanZezhong1725/issue/60
issue/60: to_tensor存储原torch张量,增加INFINI_ROOT默认路径
2 parents 89e49e3 + 26f8ae5 commit 87d1097

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

test/infiniop/libinfiniop/liboperators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from ctypes import c_int, c_int64, c_uint64, Structure, POINTER, c_size_t
66
from .datatypes import *
77
from .devices import *
8+
from pathlib import Path
89

910
Device = c_int
1011
Optype = c_int
1112

12-
INFINI_ROOT = os.environ.get("INFINI_ROOT")
13+
INFINI_ROOT = os.getenv("INFINI_ROOT") or str(Path.home() / ".infini")
1314

1415

1516
class TensorDescriptor(Structure):
@@ -30,9 +31,10 @@ def invalidate(self):
3031

3132

3233
class CTensor:
33-
def __init__(self, desc, data):
34+
def __init__(self, desc, torch_tensor):
3435
self.descriptor = desc
35-
self.data = data
36+
self.torch_tensor_ = torch_tensor
37+
self.data = torch_tensor.data_ptr()
3638

3739

3840
class Handle(Structure):

test/infiniop/libinfiniop/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def to_tensor(tensor, lib):
1919
ndim = tensor.ndimension()
2020
shape = (ctypes.c_size_t * ndim)(*tensor.shape)
2121
strides = (ctypes.c_int64 * ndim)(*(tensor.stride()))
22-
data_ptr = tensor.data_ptr()
2322
# fmt: off
2423
dt = (
2524
InfiniDtype.I8 if tensor.dtype == torch.int8 else
@@ -46,7 +45,7 @@ def to_tensor(tensor, lib):
4645
ctypes.byref(tensor_desc), ndim, shape, strides, dt
4746
)
4847
# Create Tensor
49-
return CTensor(tensor_desc, data_ptr)
48+
return CTensor(tensor_desc, tensor)
5049

5150

5251
def create_workspace(size, torch_device):

0 commit comments

Comments
 (0)