44from itertools import chain
55from pathlib import Path
66from typing import Callable , Dict , List , Tuple , Union
7+ from urllib .parse import urlparse
78
89import numpy as np
910import torch
@@ -150,17 +151,31 @@ def __init__(
150151
151152 # load weights and set devices
152153 if checkpoint_path is not None :
153- ckpt = torch .load (
154- checkpoint_path , map_location = lambda storage , loc : storage
155- )
154+ checkpoint_path = Path (checkpoint_path )
155+ # check if path is url or local and load weigths to memory
156+ if urlparse (checkpoint_path .as_posix ()).scheme :
157+ state_dict = torch .hub .load_state_dict_from_url (checkpoint_path )
158+ else :
159+ state_dict = torch .load (
160+ checkpoint_path , map_location = lambda storage , loc : storage
161+ )
162+
163+ # if the checkpoint is from lightning, the ckpt file contains a lot of other
164+ # stuff than just the state dict.
165+ if "state_dict" in state_dict .keys ():
166+ state_dict = state_dict ["state_dict" ]
156167
168+ # try loading the weights to the model
157169 try :
158- self .model .load_state_dict (ckpt [ " state_dict" ] , strict = True )
170+ msg = self .model .load_state_dict (state_dict , strict = True )
159171 except RuntimeError :
160- new_ckpt = self ._strip_state_dict (ckpt )
161- self .model .load_state_dict (new_ckpt [ "state_dict" ] , strict = True )
172+ new_ckpt = self ._strip_state_dict (state_dict )
173+ msg = self .model .load_state_dict (new_ckpt , strict = True )
162174 except BaseException as e :
163- print (e )
175+ raise RuntimeError (f"Error when loading checkpoint: { e } " )
176+
177+ print (f"Loading weights: { checkpoint_path } for inference." )
178+ print (msg )
164179
165180 assert device in ("cuda" , "cpu" , "mps" )
166181 if device == "cpu" :
@@ -213,6 +228,12 @@ def infer(self, mixed_precision: bool = False) -> None:
213228 `classes_type`, `classes_sem`, `offsets`. See more in the
214229 `FileHandler.save_masks` docs.
215230
231+
232+ Parameters
233+ ----------
234+ mixed_precision : bool, default=False
235+ If True, inference is performed with mixed precision.
236+
216237 Attributes
217238 ----------
218239 - out_masks : Dict[str, Dict[str, np.ndarray]]
@@ -224,11 +245,6 @@ def infer(self, mixed_precision: bool = False) -> None:
224245 The soft masks for each image. I.e. the soft predictions of the trained
225246 model The keys are the image names and the values are dictionaries of
226247 the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
227-
228- Parameters
229- ----------
230- mixed_precision : bool, default=False
231- If True, inference is performed with mixed precision.
232248 """
233249 self .soft_masks = {}
234250 self .out_masks = {}
@@ -291,15 +307,13 @@ def infer(self, mixed_precision: bool = False) -> None:
291307 def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
292308 """Strip te first 'model.' (generated by lightning) from the state dict keys."""
293309 state_dict = OrderedDict ()
294- for k , w in ckpt [ "state_dict" ] .items ():
310+ for k , w in ckpt .items ():
295311 if "num_batches_track" not in k :
296- # new_key = k.strip("model")[1:]
297312 spl = ["" .join (kk ) for kk in k .split ("." )]
298313 new_key = "." .join (spl [1 :])
299314 state_dict [new_key ] = w
300- ckpt ["state_dict" ] = state_dict
301315
302- return ckpt
316+ return state_dict
303317
304318 def _check_and_set_head_args (self ) -> None :
305319 """Check the model has matching head names with the head args and set them."""
0 commit comments