44import pandas as pd
55import scipy .sparse as sp
66
7+ try :
8+ import torch
9+ except :
10+ pass
11+
12+
713modALinput = Union [sp .csr_matrix , pd .DataFrame , np .ndarray , list ]
814
915
@@ -26,7 +32,13 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
2632 elif isinstance (blocks [0 ], list ):
2733 return np .concatenate (blocks ).tolist ()
2834
29- raise TypeError ('%s datatype is not supported' % type (blocks [0 ]))
35+ try :
36+ if torch .is_tensor (blocks [0 ]):
37+ return torch .cat (blocks )
38+ except :
39+ pass
40+
41+ raise TypeError ("%s datatype is not supported" % type (blocks [0 ]))
3042
3143
3244def data_hstack (blocks : Sequence [modALinput ]) -> modALinput :
@@ -48,7 +60,13 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
4860 elif isinstance (blocks [0 ], list ):
4961 return np .hstack (blocks ).tolist ()
5062
51- TypeError ('%s datatype is not supported' % type (blocks [0 ]))
63+ try :
64+ if torch .is_tensor (blocks [0 ]):
65+ return torch .cat (blocks , dim = 1 )
66+ except :
67+ pass
68+
69+ TypeError ("%s datatype is not supported" % type (blocks [0 ]))
5270
5371
5472def add_row (X : modALinput , row : modALinput ):
@@ -68,8 +86,9 @@ def add_row(X: modALinput, row: modALinput):
6886 return data_vstack ([X , row ])
6987
7088
71- def retrieve_rows (X : modALinput ,
72- I : Union [int , List [int ], np .ndarray ]) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
89+ def retrieve_rows (
90+ X : modALinput , I : Union [int , List [int ], np .ndarray ]
91+ ) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
7392 """
7493 Returns the rows I from the data set X
7594
@@ -78,34 +97,34 @@ def retrieve_rows(X: modALinput,
7897 * pandas series in case of a pandas data frame
7998 * row in case of list or numpy format
8099 """
81- if sp .issparse (X ):
82- # Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
83- # sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
84- # and sp.dia_matrix don't support indexing and need to be converted to a sparse format
85- # that does support indexing. It seems conversion to CSR is currently most efficient.
86-
87- try :
88- return X [I ]
89- except :
90- sp_format = X .getformat ()
91- return X .tocsr ()[I ].asformat (sp_format )
92- elif isinstance (X , pd .DataFrame ):
93- return X .iloc [I ]
94- elif isinstance (X , list ):
95- return np .array (X )[I ].tolist ()
96- elif isinstance (X , dict ):
97- X_return = {}
98- for key , value in X .items ():
99- X_return [key ] = retrieve_rows (value , I )
100- return X_return
101- elif isinstance (X , np .ndarray ):
102- return X [I ]
103-
104- raise TypeError ('%s datatype is not supported' % type (X ))
105100
101+ try :
102+ return X [I ]
103+ except :
104+ if sp .issparse (X ):
105+ # Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
106+ # sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
107+ # and sp.dia_matrix don't support indexing and need to be converted to a sparse format
108+ # that does support indexing. It seems conversion to CSR is currently most efficient.
106109
107- def drop_rows (X : modALinput ,
108- I : Union [int , List [int ], np .ndarray ]) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
110+ sp_format = X .getformat ()
111+ return X .tocsr ()[I ].asformat (sp_format )
112+ elif isinstance (X , pd .DataFrame ):
113+ return X .iloc [I ]
114+ elif isinstance (X , list ):
115+ return np .array (X )[I ].tolist ()
116+ elif isinstance (X , dict ):
117+ X_return = {}
118+ for key , value in X .items ():
119+ X_return [key ] = retrieve_rows (value , I )
120+ return X_return
121+
122+ raise TypeError ("%s datatype is not supported" % type (X ))
123+
124+
125+ def drop_rows (
126+ X : modALinput , I : Union [int , List [int ], np .ndarray ]
127+ ) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
109128 """
110129 Returns X without the row(s) at index/indices I
111130 """
@@ -120,7 +139,13 @@ def drop_rows(X: modALinput,
120139 elif isinstance (X , list ):
121140 return np .delete (X , I , axis = 0 ).tolist ()
122141
123- raise TypeError ('%s datatype is not supported' % type (X ))
142+ try :
143+ if torch .is_tensor (blocks [0 ]):
144+ return torch .cat (blocks )
145+ except :
146+ X [[True if row not in I else False for row in range (X .size (0 ))]]
147+
148+ raise TypeError ("%s datatype is not supported" % type (X ))
124149
125150
126151def enumerate_data (X : modALinput ):
@@ -141,17 +166,18 @@ def enumerate_data(X: modALinput):
141166 # numpy arrays and lists can readily be enumerated
142167 return enumerate (X )
143168
144- raise TypeError (' %s datatype is not supported' % type (X ))
169+ raise TypeError (" %s datatype is not supported" % type (X ))
145170
146171
147172def data_shape (X : modALinput ):
148173 """
149174 Returns the shape of the data set X
150175 """
151- if sp . issparse ( X ) or isinstance ( X , pd . DataFrame ) or isinstance ( X , np . ndarray ) :
152- # scipy.sparse, pandas and numpy all support .shape
176+ try :
177+ # scipy.sparse, torch, pandas and numpy all support .shape
153178 return X .shape
154- elif isinstance (X , list ):
155- return np .array (X ).shape
179+ except :
180+ if isinstance (X , list ):
181+ return np .array (X ).shape
156182
157- raise TypeError (' %s datatype is not supported' % type (X ))
183+ raise TypeError (" %s datatype is not supported" % type (X ))
0 commit comments