2525"""
2626
2727class lineage_population_model ():
28- def __init__ (self , mode = "cpu" ):
29- self .mode = mode
28+ def __init__ (self , device = "cpu" ):
29+
30+ self .device = device
3031 self .model = models .resnet18 (pretrained = True )
3132 self .model .fc = nn .Linear (512 , 7 ) ## resize last layer
3233 self .model_dir = os .path .dirname (__file__ )
@@ -39,14 +40,14 @@ def __init__(self, mode = "cpu"):
3940
4041 try :
4142 # print("model already downloaded, loading model...")
42- self .model .load_state_dict (torch .load (self .model_dir + "/" + self .model_name , map_location = "cpu" ))
43+ self .model .load_state_dict (torch .load (self .model_dir + "/" + self .model_name , map_location = self . device ))
4344 except :
4445 print ("model not found, downloading from:" , self .model_url )
4546 filename = wget .download (self .model_url , out = self .model_dir )
4647 # print(filename)
47- self .model .load_state_dict (torch .load (self .model_dir + "/" + self .model_name , map_location = "cpu" ))
48-
48+ self .model .load_state_dict (torch .load (self .model_dir + "/" + self .model_name , map_location = self .device ))
4949
50+ self .model .to (self .device )
5051 self .model .eval ()
5152
5253 self .transforms = transforms .Compose ([
@@ -78,8 +79,7 @@ def predict(self, image_path):
7879
7980 image = cv2 .imread (image_path , 0 )
8081 image = cv2 .cvtColor (image , cv2 .COLOR_GRAY2RGB )
81- tensor = self .transforms (image ).unsqueeze (0 )
82-
82+ tensor = self .transforms (image ).unsqueeze (0 ).to (self .device )
8383 pred = self .model (tensor ).detach ().cpu ().numpy ().reshape (1 ,- 1 )
8484
8585 pred_scaled = (self .scaler .inverse_transform (pred ).flatten ()).astype (np .uint8 )
@@ -144,13 +144,13 @@ def predict_from_video(self, video_path, csv_name = "foo.csv", save_csv = False
144144
145145 if notebook_mode == True :
146146 for i in tqdm_notebook (range (len (images )), desc = 'Predicting from video file: :' ):
147- tensor = self .transforms (images [i ]).unsqueeze (0 )
147+ tensor = self .transforms (images [i ]).unsqueeze (0 ). to ( self . device )
148148 pred = self .model (tensor ).detach ().cpu ().numpy ().reshape (1 ,- 1 )
149149 pred_scaled = (self .scaler .inverse_transform (pred ).flatten ()).astype (np .uint8 )
150150 preds .append (pred_scaled )
151151 else :
152152 for i in tqdm (range (len (images )), desc = 'Predicting from video file: :' ):
153- tensor = self .transforms (images [i ]).unsqueeze (0 )
153+ tensor = self .transforms (images [i ]).unsqueeze (0 ). to ( self . device )
154154 pred = self .model (tensor ).detach ().cpu ().numpy ().reshape (1 ,- 1 )
155155 pred_scaled = (self .scaler .inverse_transform (pred ).flatten ()).astype (np .uint8 )
156156 preds .append (pred_scaled )
@@ -180,7 +180,7 @@ def predict_from_video(self, video_path, csv_name = "foo.csv", save_csv = False
180180
181181
182182
183- def create_population_plot_from_video (self , video_path , save_plot = False , plot_name = "plot.png" , ignore_first_n_frames = 0 , ignore_last_n_frames = 0 ):
183+ def create_population_plot_from_video (self , video_path , save_plot = False , plot_name = "plot.png" , ignore_first_n_frames = 0 , ignore_last_n_frames = 0 , notebook_mode = False ):
184184
185185 """
186186 inputs{
@@ -189,6 +189,7 @@ def create_population_plot_from_video(self, video_path, save_plot = False, plot_
189189 plot_name <str> = filename of the plot image to be saved
190190 ignore_first_n_frames <int> = number of frames to drop in the start of the video
191191 ignore_last_n_frames <int> = number of frames to drop in the end of the video
192+ notebook_mode <bool> = toogle between script(False) and notebook(True), for better user interface
192193 }
193194
194195 outputs{
@@ -198,7 +199,7 @@ def create_population_plot_from_video(self, video_path, save_plot = False, plot_
198199 plots all the predictions from a video into a matplotlib.pyplot
199200
200201 """
201- df = self .predict_from_video (video_path , ignore_first_n_frames = ignore_first_n_frames , ignore_last_n_frames = ignore_last_n_frames )
202+ df = self .predict_from_video (video_path , ignore_first_n_frames = ignore_first_n_frames , ignore_last_n_frames = ignore_last_n_frames , notebook_mode = notebook_mode )
202203
203204 labels = ["A" , "E" , "M" , "P" , "C" , "D" , "Z" ]
204205
0 commit comments