Skip to content

Commit f5708ea

Browse files
authored
Fix CheckpointSaver log error (#6026)
1 parent bf55f22 commit f5708ea

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

monai/handlers/checkpoint_saver.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import logging
15+
import os
1516
import warnings
1617
from collections.abc import Mapping
1718
from typing import TYPE_CHECKING, Any
@@ -118,6 +119,7 @@ def __init__(
118119
self._key_metric_checkpoint: Checkpoint | None = None
119120
self._interval_checkpoint: Checkpoint | None = None
120121
self._name = name
122+
self._final_filename = final_filename
121123

122124
class _DiskSaver(DiskSaver):
123125
"""
@@ -148,7 +150,7 @@ def _final_func(engine: Engine) -> Any:
148150

149151
self._final_checkpoint = Checkpoint(
150152
to_save=self.save_dict,
151-
save_handler=_DiskSaver(dirname=self.save_dir, filename=final_filename),
153+
save_handler=_DiskSaver(dirname=self.save_dir, filename=self._final_filename),
152154
filename_prefix=file_prefix,
153155
score_function=_final_func,
154156
score_name="final_iteration",
@@ -271,7 +273,11 @@ def completed(self, engine: Engine) -> None:
271273
raise AssertionError
272274
if not hasattr(self.logger, "info"):
273275
raise AssertionError("Error, provided logger has not info attribute.")
274-
self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}")
276+
if self._final_filename is not None:
277+
_final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)
278+
else:
279+
_final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment]
280+
self.logger.info(f"Train completed, saved final checkpoint: {_final_checkpoint_path}")
275281

276282
def exception_raised(self, engine: Engine, e: Exception) -> None:
277283
"""Callback for train or validation/evaluation exception raised Event.
@@ -291,7 +297,11 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
291297
raise AssertionError
292298
if not hasattr(self.logger, "info"):
293299
raise AssertionError("Error, provided logger has not info attribute.")
294-
self.logger.info(f"Exception raised, saved the last checkpoint: {self._final_checkpoint.last_checkpoint}")
300+
if self._final_filename is not None:
301+
_final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)
302+
else:
303+
_final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment]
304+
self.logger.info(f"Exception raised, saved the last checkpoint: {_final_checkpoint_path}")
295305
raise e
296306

297307
def metrics_completed(self, engine: Engine) -> None:

0 commit comments

Comments
 (0)