1212from __future__ import annotations
1313
1414import logging
15+ import os
1516import warnings
1617from collections .abc import Mapping
1718from typing import TYPE_CHECKING , Any
@@ -118,6 +119,7 @@ def __init__(
118119 self ._key_metric_checkpoint : Checkpoint | None = None
119120 self ._interval_checkpoint : Checkpoint | None = None
120121 self ._name = name
122+ self ._final_filename = final_filename
121123
122124 class _DiskSaver (DiskSaver ):
123125 """
@@ -148,7 +150,7 @@ def _final_func(engine: Engine) -> Any:
148150
149151 self ._final_checkpoint = Checkpoint (
150152 to_save = self .save_dict ,
151- save_handler = _DiskSaver (dirname = self .save_dir , filename = final_filename ),
153+ save_handler = _DiskSaver (dirname = self .save_dir , filename = self . _final_filename ),
152154 filename_prefix = file_prefix ,
153155 score_function = _final_func ,
154156 score_name = "final_iteration" ,
@@ -271,7 +273,11 @@ def completed(self, engine: Engine) -> None:
271273 raise AssertionError
272274 if not hasattr (self .logger , "info" ):
273275 raise AssertionError ("Error, provided logger has not info attribute." )
274- self .logger .info (f"Train completed, saved final checkpoint: { self ._final_checkpoint .last_checkpoint } " )
276+ if self ._final_filename is not None :
277+ _final_checkpoint_path = os .path .join (self .save_dir , self ._final_filename )
278+ else :
279+ _final_checkpoint_path = self ._final_checkpoint .last_checkpoint # type: ignore[assignment]
280+ self .logger .info (f"Train completed, saved final checkpoint: { _final_checkpoint_path } " )
275281
276282 def exception_raised (self , engine : Engine , e : Exception ) -> None :
277283 """Callback for train or validation/evaluation exception raised Event.
@@ -291,7 +297,11 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
291297 raise AssertionError
292298 if not hasattr (self .logger , "info" ):
293299 raise AssertionError ("Error, provided logger has not info attribute." )
294- self .logger .info (f"Exception raised, saved the last checkpoint: { self ._final_checkpoint .last_checkpoint } " )
300+ if self ._final_filename is not None :
301+ _final_checkpoint_path = os .path .join (self .save_dir , self ._final_filename )
302+ else :
303+ _final_checkpoint_path = self ._final_checkpoint .last_checkpoint # type: ignore[assignment]
304+ self .logger .info (f"Exception raised, saved the last checkpoint: { _final_checkpoint_path } " )
295305 raise e
296306
297307 def metrics_completed (self , engine : Engine ) -> None :
0 commit comments