Skip to content

Commit 2c066d3

Browse files
authored
fix #4 : add gpu support (#34)
1 parent c4063c7 commit 2c066d3

File tree

4 files changed

+29
-28
lines changed

4 files changed

+29
-28
lines changed

devolearn/embryo_generator_model/embryo_generator_model.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def forward(self, input):
6565

6666

6767
class embryo_generator_model():
68-
def __init__(self, mode = "cpu"):
68+
def __init__(self, device = "cpu"):
6969

7070
"""
7171
ngf = size of feature maps in generator
@@ -74,11 +74,11 @@ def __init__(self, mode = "cpu"):
7474
Do not tweak these unless you're changing the Generator() with a new model with a different architecture.
7575
7676
"""
77-
77+
self.device = device
7878
self.ngf = 128 ## generated image size
7979
self.nz = 128
8080
self.nc = 1
81-
self.generator= Generator(self.ngf, self.nz, self.nc)
81+
self.generator = Generator(self.ngf, self.nz, self.nc)
8282
self.model_url = "https://raw.githubusercontent.com/DevoLearn/devolearn/master/devolearn/embryo_generator_model/embryo_generator.pth"
8383
self.model_name = "embryo_generator.pth"
8484
self.model_dir = os.path.dirname(__file__)
@@ -87,15 +87,16 @@ def __init__(self, mode = "cpu"):
8787

8888
try:
8989
# print("model already downloaded, loading model...")
90-
self.generator.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu"))
90+
self.generator.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= self.device))
9191
except:
9292
print("model not found, downloading from: ", self.model_url)
9393
if os.path.isdir(self.model_dir) == False:
9494
os.mkdir(self.model_dir)
9595
filename = wget.download(self.model_url, out= self.model_dir)
9696
# print(filename)
97-
self.generator.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu"))
98-
97+
self.generator.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= self.device))
98+
99+
self.generator.to(self.device)
99100

100101

101102

@@ -118,7 +119,7 @@ def generate(self, image_size = (700,500)):
118119
generated image to the desired size.
119120
"""
120121
with torch.no_grad():
121-
noise = torch.randn([1,128,1,1])
122+
noise = torch.randn([1,128,1,1]).to(self.device)
122123
im = self.generator(noise)[0][0].cpu().detach().numpy()
123124
im = cv2.resize(im, image_size)
124125
im = 255 - cv2.convertScaleAbs(im, alpha=(255.0)) ## temporary fix against inverted images

devolearn/embryo_segmentor/embryo_segmentor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,18 @@ def generate_centroid_image(thresh):
6363
return centroid_image, centroids
6464

6565
class embryo_segmentor():
66-
def __init__(self):
66+
def __init__(self, device = "cpu"):
6767

6868
"""
6969
Segments the c. elegans embryo from images/videos,
7070
depends on segmentation-models-pytorch for the model backbone
7171
7272
"""
73-
73+
self.device = device
7474
self.ENCODER = 'resnet18'
7575
self.ENCODER_WEIGHTS = 'imagenet'
7676
self.CLASSES = ["nucleus"]
7777
self.ACTIVATION = 'sigmoid'
78-
self.DEVICE = 'cpu'
7978
self.in_channels = 1
8079
self.model_url = "https://github.com/DevoLearn/devolearn/raw/master/devolearn/embryo_segmentor/3d_segmentation_model.pth"
8180
self.model_name = "3d_segmentation_model.pth"
@@ -92,16 +91,16 @@ def __init__(self):
9291

9392
try:
9493
# print("model already downloaded, loading model...")
95-
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu")
94+
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)
9695
except:
9796
print("model not found, downloading from:", self.model_url)
9897
if os.path.isdir(self.model_dir) == False:
9998
os.mkdir(self.model_dir)
10099
filename = wget.download(self.model_url, out= self.model_dir)
101100
# print(filename)
102-
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu")
103-
101+
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)
104102

103+
self.model.to(self.device)
105104
self.model.eval()
106105

107106
self.mini_transform = transforms.Compose([
@@ -129,7 +128,7 @@ def predict(self, image_path, pred_size = (350,250), centroid_mode = False):
129128
"""
130129

131130
im = cv2.imread(image_path,0)
132-
tensor = self.mini_transform(im).unsqueeze(0)
131+
tensor = self.mini_transform(im).unsqueeze(0).to(self.device)
133132
res = self.model(tensor).detach().cpu().numpy()[0][0]
134133
res = cv2.resize(res,pred_size)
135134
if centroid_mode == False:
@@ -165,7 +164,7 @@ def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "p
165164
if notebook_mode == True:
166165
for i in tqdm_notebook(range(len(images)), desc = "saving predictions: "):
167166
save_name = save_folder + "/" + str(i) + ".jpg"
168-
tensor = self.mini_transform(images[i]).unsqueeze(0)
167+
tensor = self.mini_transform(images[i]).unsqueeze(0).to(self.device)
169168
res = self.model(tensor).detach().cpu().numpy()[0][0]
170169

171170
if centroid_mode == True:
@@ -177,7 +176,7 @@ def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "p
177176
else :
178177
for i in tqdm(range(len(images)), desc = "saving predictions: "):
179178
save_name = save_folder + "/" + str(i) + ".jpg"
180-
tensor = self.mini_transform(images[i]).unsqueeze(0)
179+
tensor = self.mini_transform(images[i]).unsqueeze(0).to(self.device)
181180
res = self.model(tensor).detach().cpu().numpy()[0][0]
182181

183182
if centroid_mode == True:

devolearn/lineage_population_model/lineage_population_model.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
"""
2626

2727
class lineage_population_model():
28-
def __init__(self, mode = "cpu"):
29-
self.mode = mode
28+
def __init__(self, device = "cpu"):
29+
30+
self.device = device
3031
self.model = models.resnet18(pretrained = True)
3132
self.model.fc = nn.Linear(512, 7) ## resize last layer
3233
self.model_dir = os.path.dirname(__file__)
@@ -39,14 +40,14 @@ def __init__(self, mode = "cpu"):
3940

4041
try:
4142
# print("model already downloaded, loading model...")
42-
self.model.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu"))
43+
self.model.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= self.device))
4344
except:
4445
print("model not found, downloading from:", self.model_url)
4546
filename = wget.download(self.model_url, out= self.model_dir)
4647
# print(filename)
47-
self.model.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= "cpu"))
48-
48+
self.model.load_state_dict(torch.load(self.model_dir + "/" + self.model_name, map_location= self.device))
4949

50+
self.model.to(self.device)
5051
self.model.eval()
5152

5253
self.transforms = transforms.Compose([
@@ -78,8 +79,7 @@ def predict(self, image_path):
7879

7980
image = cv2.imread(image_path, 0)
8081
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
81-
tensor = self.transforms(image).unsqueeze(0)
82-
82+
tensor = self.transforms(image).unsqueeze(0).to(self.device)
8383
pred = self.model(tensor).detach().cpu().numpy().reshape(1,-1)
8484

8585
pred_scaled = (self.scaler.inverse_transform(pred).flatten()).astype(np.uint8)
@@ -144,13 +144,13 @@ def predict_from_video(self, video_path, csv_name = "foo.csv", save_csv = False
144144

145145
if notebook_mode == True:
146146
for i in tqdm_notebook(range(len(images)), desc='Predicting from video file: :'):
147-
tensor = self.transforms(images[i]).unsqueeze(0)
147+
tensor = self.transforms(images[i]).unsqueeze(0).to(self.device)
148148
pred = self.model(tensor).detach().cpu().numpy().reshape(1,-1)
149149
pred_scaled = (self.scaler.inverse_transform(pred).flatten()).astype(np.uint8)
150150
preds.append(pred_scaled)
151151
else :
152152
for i in tqdm(range(len(images)), desc='Predicting from video file: :'):
153-
tensor = self.transforms(images[i]).unsqueeze(0)
153+
tensor = self.transforms(images[i]).unsqueeze(0).to(self.device)
154154
pred = self.model(tensor).detach().cpu().numpy().reshape(1,-1)
155155
pred_scaled = (self.scaler.inverse_transform(pred).flatten()).astype(np.uint8)
156156
preds.append(pred_scaled)
@@ -180,7 +180,7 @@ def predict_from_video(self, video_path, csv_name = "foo.csv", save_csv = False
180180

181181

182182

183-
def create_population_plot_from_video(self, video_path, save_plot = False, plot_name = "plot.png", ignore_first_n_frames = 0, ignore_last_n_frames = 0 ):
183+
def create_population_plot_from_video(self, video_path, save_plot = False, plot_name = "plot.png", ignore_first_n_frames = 0, ignore_last_n_frames = 0, notebook_mode = False):
184184

185185
"""
186186
inputs{
@@ -189,6 +189,7 @@ def create_population_plot_from_video(self, video_path, save_plot = False, plot_
189189
plot_name <str> = filename of the plot image to be saved
190190
ignore_first_n_frames <int> = number of frames to drop in the start of the video
191191
ignore_last_n_frames <int> = number of frames to drop in the end of the video
192+
notebook_mode <bool> = toogle between script(False) and notebook(True), for better user interface
192193
}
193194
194195
outputs{
@@ -198,7 +199,7 @@ def create_population_plot_from_video(self, video_path, save_plot = False, plot_
198199
plots all the predictions from a video into a matplotlib.pyplot
199200
200201
"""
201-
df = self.predict_from_video(video_path, ignore_first_n_frames = ignore_first_n_frames, ignore_last_n_frames = ignore_last_n_frames )
202+
df = self.predict_from_video(video_path, ignore_first_n_frames = ignore_first_n_frames, ignore_last_n_frames = ignore_last_n_frames, notebook_mode = notebook_mode)
202203

203204
labels = ["A", "E", "M", "P", "C", "D", "Z"]
204205

devolearn/tests/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class test(unittest.TestCase):
2020
def test_lineage_population_model(self):
2121
test_dir = os.path.dirname(__file__)
2222

23-
model = lineage_population_model(mode = "cpu")
23+
model = lineage_population_model(device = "cpu")
2424
pred = model.predict(image_path = test_dir + "/" + "sample_data/images/embryo_sample.png")
2525
self.assertTrue(isinstance(pred, dict), "should be dict")
2626

0 commit comments

Comments
 (0)