diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index 4c4f43b80d..c9e3cdbbff 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -316,7 +316,8 @@ def get_values(self, values: list[np.ndarray] | None): class SaveSensitivityGeoH5(SaveArrayGeoH5): """ - Save the model at the current iteration to a geoh5 file. + Save the approximate sensitivities (JtJdiag) at the current + iteration to a geoh5 file. """ _attribute_type = "sensitivities" @@ -337,7 +338,7 @@ def get_values(self, values: list[np.ndarray] | None): class SaveDataGeoH5(SaveArrayGeoH5): """ - Save the model at the current iteration to a geoh5 file. + Save the predicted data at the current iteration to a geoh5 file. """ _attribute_type = "predicted" @@ -384,22 +385,28 @@ def joint_index(self, value: list[int] | None): class SaveLogFilesGeoH5(BaseSaveGeoH5): + """ + Save iteration metrics to log files and attach them to the geoh5 file as bytes. + + :param h5_object: The geoh5 object to which the log files will be attached. + :param base_name: The base name of the log files. + """ + def __init__( self, h5_object, + base_name, **kwargs, ): + self.base_name = base_name super().__init__(h5_object, **kwargs) def write(self, iteration: int, **_): dirpath = Path(self._workspace.h5file).parent - filepath = dirpath / "SimPEG.out" + filepath = dirpath / f"{self.base_name}.out" - if iteration == 0: - with open(filepath, "w", encoding="utf-8") as f: - f.write("iteration beta phi_d phi_m time\n") log = [] - with open(dirpath / "SimPEG.log", "r", encoding="utf-8") as file: + with open(dirpath / f"{self.base_name}.log", "r", encoding="utf-8") as file: iteration = 0 for line in file: val = re.findall(r"[+-]?(?:0|[1-9]\d*)(?:\.\d*)?(?:[eE][+-]?\d+)", line) @@ -408,6 +415,10 @@ def write(self, iteration: int, **_): iteration += 1 if len(log) > 0: + if not filepath.exists(): + with open(filepath, "w", encoding="utf-8") as f: + f.write("iteration beta phi_d phi_m time\n") + with open(filepath, "a", encoding="utf-8") as file: date_time = datetime.now().strftime("%b-%d-%Y:%H:%M:%S") @@ -427,11 +438,11 @@ def save_log(self): h5_object = w_s.get_entity(self.h5_object)[0] for file in [ - "SimPEG.out", - "SimPEG.log", - "ChiFactors.log", + ".out", + ".log", + ".chi", ]: - filepath = dirpath / file + filepath = dirpath / f"{self.base_name}{file}" if not filepath.is_file(): continue @@ -439,7 +450,7 @@ def save_log(self): with open(filepath, "rb") as f: raw_file = f.read() - file_entity = h5_object.get_entity(file)[0] + file_entity = h5_object.get_entity(f"{self.base_name}{file}")[0] if file_entity is None: file_entity = h5_object.add_file(filepath)