From 402f8a186e234d0cab103c74bfceaf7f5e50e292 Mon Sep 17 00:00:00 2001 From: alinzh Date: Sat, 14 Dec 2024 11:10:29 +0000 Subject: [PATCH 1/7] Change panda on polars in retrieve MPDS and tests --- mpds_client/retrieve_MPDS.py | 25 +++++++------------ mpds_client/test_retrieve_MPDS.py | 41 ++++++++++++++++++++----------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py index d618ca0..e8f7dde 100755 --- a/mpds_client/retrieve_MPDS.py +++ b/mpds_client/retrieve_MPDS.py @@ -1,4 +1,3 @@ - import os import sys import time @@ -8,11 +7,11 @@ import httplib2 import ujson as json -import pandas as pd +import polars as pl from numpy import array_split import jmespath -from .errors import APIError +from errors import APIError use_pmg, use_ase = False, False @@ -107,7 +106,6 @@ class MPDSDataRetrieval(object): verbose = True debug = False - def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None): """ MPDS API consumer constructor @@ -127,7 +125,6 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug= self.verbose = verbose if verbose is not None else self.verbose self.debug = debug or self.debug - def _request(self, query, phases=None, page=0, pagesize=None): phases = ','.join([str(int(x)) for x in phases]) if phases else '' @@ -141,7 +138,7 @@ def _request(self, query, phases=None, page=0, pagesize=None): }) if self.debug: - print('curl -XGET -HKey:%s \'%s\'' % (self.api_key, uri)) + print('curl -XGET -HKey:%s \"%s\"' % (self.api_key, uri)) response, content = self.network.request( uri=uri, @@ -165,7 +162,6 @@ def _request(self, query, phases=None, page=0, pagesize=None): return content - def _massage(self, array, fields): if not fields: return array @@ -190,7 +186,6 @@ def _massage(self, array, fields): return output - def count_data(self, search, phases=None, **kwargs): """ Calculate the number of entries matching the keyword(s) specified @@ -217,7 +212,6 @@ def count_data(self, search, phases=None, **kwargs): return result['count'] - def get_data(self, search, phases=None, fields=default_fields): """ Retrieve data in JSON. @@ -287,7 +281,7 @@ def get_data(self, search, phases=None, fields=default_fields): if self.verbose: sys.stdout.write("\r\t%d%% of step %s from %s" % ( (counter/result['npages']) * 100, step, nsteps) - ) + ) sys.stdout.flush() tot_count += hits_count @@ -301,10 +295,9 @@ def get_data(self, search, phases=None, fields=default_fields): return output - def get_dataframe(self, *args, **kwargs): """ - Retrieve data as a Pandas dataframe. + Retrieve data as a Polars dataframe. Args: search: (dict) Search query like {"categ_A": "val_A", "categ_B": "val_B"}, @@ -316,7 +309,7 @@ def get_dataframe(self, *args, **kwargs): (if None is given, all the fields will be present) columns: (list) Column names for Pandas dataframe - Returns: (object) Pandas dataframe object containing the results + Returns: (object) Polars dataframe object containing the results """ columns = kwargs.get('columns') if columns: @@ -324,8 +317,8 @@ def get_dataframe(self, *args, **kwargs): else: columns = self.default_titles - return pd.DataFrame(self.get_data(*args, **kwargs), columns=columns) - + data = self.get_data(*args, **kwargs) + return pl.DataFrame(data, schema=columns) def get_crystals(self, search, phases=None, flavor='pmg', **kwargs): search["props"] = "atomic structure" @@ -343,7 +336,6 @@ def get_crystals(self, search, phases=None, flavor='pmg', **kwargs): return crystals - @staticmethod def compile_crystal(datarow, flavor='pmg'): """ @@ -412,3 +404,4 @@ def compile_crystal(datarow, flavor='pmg'): ) else: raise APIError("Crystal structure treatment unavailable") + diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py index c0c52a6..3cecae7 100755 --- a/mpds_client/test_retrieve_MPDS.py +++ b/mpds_client/test_retrieve_MPDS.py @@ -1,9 +1,7 @@ - import unittest #import warnings -import numpy as np -import pandas as pd +import polars as pl import httplib2 import ujson as json @@ -11,6 +9,7 @@ from jsonschema.exceptions import ValidationError from retrieve_MPDS import MPDSDataRetrieval +import logging class MPDSDataRetrievalTest(unittest.TestCase): @@ -69,7 +68,6 @@ def test_crystal_structure(self): self.assertEqual(len(ase_obj), 6) def test_get_crystals(self): - query = { "elements": "Ti-O", "classes": "binary", @@ -78,6 +76,7 @@ def test_get_crystals(self): } client = MPDSDataRetrieval() ntot = client.count_data(query) + logging.debug(f"Value of ntot: {ntot}") self.assertTrue(150 < ntot < 175) crystals = client.get_crystals(query, flavor='ase') @@ -115,9 +114,14 @@ def test_retrieval_of_phases(self): fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, columns=['Phid', 'Object'] ) - answer_one = answer_one[np.isfinite(answer_one['Phid'])] - phases_one = answer_one['Phid'].astype(int).tolist() - + if not(isinstance(answer_one, pl.DataFrame)): + print(type(answer_one)) + raise ValueError("answer_one is not a Polars DataFrame", type(answer_one)) + + answer_one = answer_one.filter(pl.col('Phid').is_not_null()) + answer_one = answer_one.with_columns(pl.col('Phid').cast(pl.Int32)) + phases_one = answer_one['Phid'].to_list() + self.assertTrue(len(phases_one) > client_one.maxnphases) result_one = client_one.get_dataframe( @@ -135,8 +139,12 @@ def test_retrieval_of_phases(self): fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, columns=['Phid', 'Object'] ) - answer_two = answer_two[np.isfinite(answer_two['Phid'])] - phases_two = answer_two['Phid'].astype(int).tolist() + if not(isinstance(answer_one, pl.DataFrame)): + print(type(answer_two)) + raise ValueError("answer_one is not a Polars DataFrame, is", type(answer_two)) + + answer_two = answer_two.filter(pl.col('Phid').is_not_null()) + phases_two = answer_two['Phid'].cast(pl.Int32).to_list() self.assertTrue(len(phases_two) < client_two.maxnphases) @@ -150,10 +158,13 @@ def test_retrieval_of_phases(self): self.assertEqual(len(result_one), len(result_two)) # check equality of result_one and result_two - merge = pd.concat([result_one, result_two]) - merge = merge.reset_index(drop=True) - merge_gpby = merge.groupby(list(merge.columns)) - idx = [x[0] for x in merge_gpby.groups.values() if len(x) == 1] - self.assertTrue(merge.reindex(idx).empty) + merge = pl.concat([result_one, result_two]) + merge = merge.with_columns(pl.Series("index", range(len(merge)))) + merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(pl.len()) + idx = [x[0] for x in merge_gpby.iter_rows() if x[-1] == 1] + + self.assertTrue(merge.filter(pl.col("index").is_in(idx)).is_empty()) -if __name__ == "__main__": unittest.main() +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From dc3818c76958c68ffece83f641e7d4e2688d8e8f Mon Sep 17 00:00:00 2001 From: alinzh Date: Sat, 14 Dec 2024 11:49:19 +0000 Subject: [PATCH 2/7] Update export MPDS and add tests --- mpds_client/export_MPDS.py | 125 +++++++++++++++++--------------- mpds_client/test_export_MPDS.py | 64 ++++++++++++++++ 2 files changed, 132 insertions(+), 57 deletions(-) create mode 100644 mpds_client/test_export_MPDS.py diff --git a/mpds_client/export_MPDS.py b/mpds_client/export_MPDS.py index 824dfb4..f2858eb 100755 --- a/mpds_client/export_MPDS.py +++ b/mpds_client/export_MPDS.py @@ -4,9 +4,9 @@ """ import os import random - import ujson as json -import pandas as pd +import polars as pl +from typing import Union class MPDSExport(object): @@ -36,7 +36,9 @@ def _gen_basename(cls): return "".join(basename) @classmethod - def _get_title(cls, term): + def _get_title(cls, term: Union[str, int]): + if isinstance(term, int): + return str(term) return cls.human_names.get(term, term.capitalize()) @classmethod @@ -50,75 +52,84 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs): cls._verify_export_dir() plot = {"use_visavis_type": plottype, "payload": {}} - if isinstance(data, pd.DataFrame): - iter_data = data.iterrows - pointers = columns - else: - iter_data = lambda: enumerate(data) - pointers = range(len(data[0])) + if not isinstance(data, pl.DataFrame): + raise TypeError("The 'data' parameter must be a Polars DataFrame") + + # сheck that columns are valid + if not all(col in data.columns for col in columns): + raise ValueError("Some specified columns are not in the DataFrame") if fmt == 'csv': + # export to CSV fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv") - f_export = open(fmt_export, "w") - f_export.write("%s\n" % ",".join(map(str, columns))) - for _, row in iter_data(): - f_export.write("%s\n" % ",".join([str(row[i]) for i in pointers])) - f_export.close() + with open(fmt_export, "w") as f_export: + f_export.write(",".join(columns) + "\n") + for row in data.select(columns).iter_rows(): + f_export.write(",".join(map(str, row)) + "\n") - else: + elif fmt == 'json': + # export to JSON fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json") - f_export = open(fmt_export, "w") - - if plottype == 'bar': - - plot["payload"] = {"x": [], "y": [], "xtitle": cls._get_title(columns[0]), "ytitle": cls._get_title(columns[1])} - - for _, row in iter_data(): - plot["payload"]["x"].append(row[pointers[0]]) - plot["payload"]["y"].append(row[pointers[1]]) - - elif plottype == 'plot3d': + with open(fmt_export, "w") as f_export: + if plottype == 'bar': + # bar plot payload + plot["payload"] = { + "x": [data[columns[0]].to_list()], + "y": data[columns[1]].to_list(), + "xtitle": cls._get_title(columns[0]), + "ytitle": cls._get_title(columns[1]) + } + + elif plottype == 'plot3d': + # 3D plot payload + plot["payload"] = { + "points": {"x": [], "y": [], "z": [], "labels": []}, + "meshes": [], + "xtitle": cls._get_title(columns[0]), + "ytitle": cls._get_title(columns[1]), + "ztitle": cls._get_title(columns[2]) + } + recent_mesh = None + for row in data.iter_rows(): + plot["payload"]["points"]["x"].append(row[data.columns.index(columns[0])]) + plot["payload"]["points"]["y"].append(row[data.columns.index(columns[1])]) + plot["payload"]["points"]["z"].append(row[data.columns.index(columns[2])]) + plot["payload"]["points"]["labels"].append(row[data.columns.index(columns[3])]) + + if row[data.columns.index(columns[4])] != recent_mesh: + plot["payload"]["meshes"].append({"x": [], "y": [], "z": []}) + recent_mesh = row[data.columns.index(columns[4])] + + if plot["payload"]["meshes"]: + plot["payload"]["meshes"][-1]["x"].append(row[data.columns.index(columns[0])]) + plot["payload"]["meshes"][-1]["y"].append(row[data.columns.index(columns[1])]) + plot["payload"]["meshes"][-1]["z"].append(row[data.columns.index(columns[2])]) + else: + raise RuntimeError(f"Error: {plottype} is an unknown plot type") + + if kwargs: + plot["payload"].update(kwargs) + + # write JSON to file + f_export.write(json.dumps(plot, escape_forward_slashes=False, indent=4)) - plot["payload"]["points"] = {"x": [], "y": [], "z": [], "labels": []} - plot["payload"]["meshes"] = [] - plot["payload"]["xtitle"] = cls._get_title(columns[0]) - plot["payload"]["ytitle"] = cls._get_title(columns[1]) - plot["payload"]["ztitle"] = cls._get_title(columns[2]) - recent_mesh = 0 - - for _, row in iter_data(): - plot["payload"]["points"]["x"].append(row[pointers[0]]) - plot["payload"]["points"]["y"].append(row[pointers[1]]) - plot["payload"]["points"]["z"].append(row[pointers[2]]) - plot["payload"]["points"]["labels"].append(row[pointers[3]]) - - if row[4] != recent_mesh: - plot["payload"]["meshes"].append({"x": [], "y": [], "z": []}) - recent_mesh = row[4] - - if plot["payload"]["meshes"]: - plot["payload"]["meshes"][-1]["x"].append(row[pointers[0]]) - plot["payload"]["meshes"][-1]["y"].append(row[pointers[1]]) - plot["payload"]["meshes"][-1]["z"].append(row[pointers[2]]) - - if kwargs: - plot["payload"].update(kwargs) - - else: raise RuntimeError("\r\nError: %s is an unknown plot type" % plottype) + else: + raise ValueError(f"Unsupported format: {fmt}") - f_export.write(json.dumps(plot, escape_forward_slashes=False, indent=4)) - f_export.close() + return fmt_export - return fmt_export @classmethod def save_df(cls, frame, tag): cls._verify_export_dir() + if not isinstance(frame, pl.DataFrame): + raise TypeError("Input frame must be a Polars DataFrame") + if tag is None: tag = '-' - pkl_export = os.path.join(cls.export_dir, 'df' + str(tag) + '_' + cls._gen_basename() + ".pkl") - frame.to_pickle(pkl_export, protocol=2) # Py2-3 compat + pkl_export = os.path.join(cls.export_dir, f'df{tag}_{cls._gen_basename()}.parquet') + frame.write_parquet(pkl_export) # cos pickle is not supported in polars return pkl_export @classmethod diff --git a/mpds_client/test_export_MPDS.py b/mpds_client/test_export_MPDS.py new file mode 100644 index 0000000..901087e --- /dev/null +++ b/mpds_client/test_export_MPDS.py @@ -0,0 +1,64 @@ +import unittest +import os +import polars as pl +from export_MPDS import MPDSExport + + +class TestMPDSExport(unittest.TestCase): + def test_save_plot_csv(self): + """Test saving a plot in CSV format.""" + data = pl.DataFrame({ + "length": [1.2, 1.5, 1.8, 2.0, 2.2], + "occurrence": [10, 15, 8, 20, 12] + }) + columns = ["length", "occurrence"] + plottype = "bar" + + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='csv') + self.assertTrue(os.path.isfile(exported_file)) + self.assertTrue(exported_file.endswith(".csv")) + + def test_save_plot_json(self): + """Test saving a plot in JSON format.""" + data = pl.DataFrame({ + "length": [1.2, 1.5, 1.8, 2.0, 2.2], + "occurrence": [10, 15, 8, 20, 12] + }) + columns = ["length", "occurrence"] + plottype = "bar" + + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json') + self.assertTrue(os.path.isfile(exported_file)) + self.assertTrue(exported_file.endswith(".json")) + + def test_save_plot_3d_json(self): + """Test saving a 3D plot in JSON format.""" + data = pl.DataFrame({ + "x": [1, 2, 3, 4], + "y": [5, 6, 7, 8], + "z": [9, 10, 11, 12], + "labels": ["A", "B", "C", "D"], + "meshes_id": [1, 1, 2, 2] + }) + columns = ["x", "y", "z", "labels", "meshes_id"] + plottype = "plot3d" + + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json') + self.assertTrue(os.path.isfile(exported_file)) + self.assertTrue(exported_file.endswith(".json")) + + def test_save_df(self): + """Test saving Polars DataFrame.""" + data = pl.DataFrame({ + "column1": [1, 2, 3], + "column2": [4, 5, 6] + }) + tag = "test" + + exported_file = MPDSExport.save_df(data, tag) + self.assertTrue(os.path.isfile(exported_file)) + self.assertTrue(exported_file.endswith(".parquet")) + + +if __name__ == "__main__": + unittest.main() From 421459305303300d66c06911cc1a41cb38284b2d Mon Sep 17 00:00:00 2001 From: alinzh Date: Sat, 14 Dec 2024 11:54:33 +0000 Subject: [PATCH 3/7] Polishing by linters --- mpds_client/__init__.py | 3 +- mpds_client/errors.py | 30 ++-- mpds_client/export_MPDS.py | 80 ++++++---- mpds_client/retrieve_MPDS.py | 234 +++++++++++++++++------------- mpds_client/test_export_MPDS.py | 43 +++--- mpds_client/test_retrieve_MPDS.py | 102 +++++++------ 6 files changed, 275 insertions(+), 217 deletions(-) diff --git a/mpds_client/__init__.py b/mpds_client/__init__.py index aa26463..00f8de4 100755 --- a/mpds_client/__init__.py +++ b/mpds_client/__init__.py @@ -1,4 +1,3 @@ - import sys from .retrieve_MPDS import MPDSDataTypes, APIError, MPDSDataRetrieval @@ -7,4 +6,4 @@ MIN_PY_VER = (3, 5) -assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER) \ No newline at end of file +assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER) diff --git a/mpds_client/errors.py b/mpds_client/errors.py index d20cdf1..de3d0ef 100755 --- a/mpds_client/errors.py +++ b/mpds_client/errors.py @@ -1,20 +1,20 @@ - class APIError(Exception): """ Simple error handling """ + codes = { - 204: 'No Results', - 400: 'Bad Request', - 401: 'Unauthorized', - 402: 'Unauthorized (Payment Required)', - 403: 'Forbidden', - 404: 'Not Found', - 413: 'Too Much Data Given', - 429: 'Too Many Requests (Rate Limiting)', - 500: 'Internal Server Error', - 501: 'Not Implemented', - 503: 'Service Unavailable' + 204: "No Results", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Unauthorized (Payment Required)", + 403: "Forbidden", + 404: "Not Found", + 413: "Too Much Data Given", + 429: "Too Many Requests (Rate Limiting)", + 500: "Internal Server Error", + 501: "Not Implemented", + 503: "Service Unavailable", } def __init__(self, msg, code=0): @@ -23,4 +23,8 @@ def __init__(self, msg, code=0): self.code = code def __str__(self): - return "HTTP error code %s: %s (%s)" % (self.code, self.codes.get(self.code, 'Communication Error'), self.msg) \ No newline at end of file + return "HTTP error code %s: %s (%s)" % ( + self.code, + self.codes.get(self.code, "Communication Error"), + self.msg, + ) diff --git a/mpds_client/export_MPDS.py b/mpds_client/export_MPDS.py index f2858eb..9f2ee57 100755 --- a/mpds_client/export_MPDS.py +++ b/mpds_client/export_MPDS.py @@ -2,6 +2,7 @@ Utilities for convenient exporting the MPDS data """ + import os import random import ujson as json @@ -10,13 +11,12 @@ class MPDSExport(object): - export_dir = "/tmp/_MPDS" human_names = { - 'length': 'Bond lengths, A', - 'occurrence': 'Counts', - 'bandgap': 'Band gap, eV' + "length": "Bond lengths, A", + "occurrence": "Counts", + "bandgap": "Band gap, eV", } @classmethod @@ -32,7 +32,11 @@ def _gen_basename(cls): basename = [] random.seed() for _ in range(12): - basename.append(random.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")) + basename.append( + random.choice( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + ) + ) return "".join(basename) @classmethod @@ -42,7 +46,7 @@ def _get_title(cls, term: Union[str, int]): return cls.human_names.get(term, term.capitalize()) @classmethod - def save_plot(cls, data, columns, plottype, fmt='json', **kwargs): + def save_plot(cls, data, columns, plottype, fmt="json", **kwargs): """ Exports the data in the following formats for plotting: @@ -59,7 +63,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs): if not all(col in data.columns for col in columns): raise ValueError("Some specified columns are not in the DataFrame") - if fmt == 'csv': + if fmt == "csv": # export to CSV fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv") with open(fmt_export, "w") as f_export: @@ -67,43 +71,59 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs): for row in data.select(columns).iter_rows(): f_export.write(",".join(map(str, row)) + "\n") - elif fmt == 'json': + elif fmt == "json": # export to JSON fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json") with open(fmt_export, "w") as f_export: - if plottype == 'bar': + if plottype == "bar": # bar plot payload plot["payload"] = { "x": [data[columns[0]].to_list()], "y": data[columns[1]].to_list(), "xtitle": cls._get_title(columns[0]), - "ytitle": cls._get_title(columns[1]) + "ytitle": cls._get_title(columns[1]), } - elif plottype == 'plot3d': + elif plottype == "plot3d": # 3D plot payload plot["payload"] = { "points": {"x": [], "y": [], "z": [], "labels": []}, "meshes": [], "xtitle": cls._get_title(columns[0]), "ytitle": cls._get_title(columns[1]), - "ztitle": cls._get_title(columns[2]) + "ztitle": cls._get_title(columns[2]), } recent_mesh = None for row in data.iter_rows(): - plot["payload"]["points"]["x"].append(row[data.columns.index(columns[0])]) - plot["payload"]["points"]["y"].append(row[data.columns.index(columns[1])]) - plot["payload"]["points"]["z"].append(row[data.columns.index(columns[2])]) - plot["payload"]["points"]["labels"].append(row[data.columns.index(columns[3])]) + plot["payload"]["points"]["x"].append( + row[data.columns.index(columns[0])] + ) + plot["payload"]["points"]["y"].append( + row[data.columns.index(columns[1])] + ) + plot["payload"]["points"]["z"].append( + row[data.columns.index(columns[2])] + ) + plot["payload"]["points"]["labels"].append( + row[data.columns.index(columns[3])] + ) if row[data.columns.index(columns[4])] != recent_mesh: - plot["payload"]["meshes"].append({"x": [], "y": [], "z": []}) + plot["payload"]["meshes"].append( + {"x": [], "y": [], "z": []} + ) recent_mesh = row[data.columns.index(columns[4])] if plot["payload"]["meshes"]: - plot["payload"]["meshes"][-1]["x"].append(row[data.columns.index(columns[0])]) - plot["payload"]["meshes"][-1]["y"].append(row[data.columns.index(columns[1])]) - plot["payload"]["meshes"][-1]["z"].append(row[data.columns.index(columns[2])]) + plot["payload"]["meshes"][-1]["x"].append( + row[data.columns.index(columns[0])] + ) + plot["payload"]["meshes"][-1]["y"].append( + row[data.columns.index(columns[1])] + ) + plot["payload"]["meshes"][-1]["z"].append( + row[data.columns.index(columns[2])] + ) else: raise RuntimeError(f"Error: {plottype} is an unknown plot type") @@ -116,8 +136,7 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs): else: raise ValueError(f"Unsupported format: {fmt}") - return fmt_export - + return fmt_export @classmethod def save_df(cls, frame, tag): @@ -126,22 +145,25 @@ def save_df(cls, frame, tag): raise TypeError("Input frame must be a Polars DataFrame") if tag is None: - tag = '-' + tag = "-" - pkl_export = os.path.join(cls.export_dir, f'df{tag}_{cls._gen_basename()}.parquet') - frame.write_parquet(pkl_export) # cos pickle is not supported in polars + pkl_export = os.path.join( + cls.export_dir, "df" + str(tag) + "_" + cls._gen_basename() + ".pkl" + ) + frame.write_parquet(pkl_export) # cos pickle is not supported in polars return pkl_export @classmethod def save_model(cls, skmodel, tag): - import _pickle as cPickle cls._verify_export_dir() if tag is None: - tag = '-' + tag = "-" - pkl_export = os.path.join(cls.export_dir, 'ml' + str(tag) + '_' + cls._gen_basename() + ".pkl") - with open(pkl_export, 'wb') as f: + pkl_export = os.path.join( + cls.export_dir, "ml" + str(tag) + "_" + cls._gen_basename() + ".pkl" + ) + with open(pkl_export, "wb") as f: cPickle.dump(skmodel, f) return pkl_export diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py index e8f7dde..5cdf34c 100755 --- a/mpds_client/retrieve_MPDS.py +++ b/mpds_client/retrieve_MPDS.py @@ -18,22 +18,26 @@ try: from pymatgen.core.structure import Structure from pymatgen.core.lattice import Lattice + use_pmg = True -except ImportError: pass +except ImportError: + pass try: from ase import Atom from ase.spacegroup import crystal + use_ase = True -except ImportError: pass +except ImportError: + pass if not use_pmg and not use_ase: warnings.warn("Crystal structure treatment unavailable") -__author__ = 'Evgeny Blokhin ' -__copyright__ = 'Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics' -__license__ = 'MIT' +__author__ = "Evgeny Blokhin " +__copyright__ = "Copyright (c) 2020, Evgeny Blokhin, Tilde Materials Informatics" +__license__ = "MIT" class MPDSDataTypes(object): @@ -67,46 +71,51 @@ class MPDSDataRetrieval(object): *or* jsonobj = client.get_data({"formula":"SrTiO3"}, fields={}) """ + default_fields = { - 'S': [ - 'phase_id', - 'chemical_formula', - 'sg_n', - 'entry', - lambda: 'crystal structure', - lambda: 'angstrom' + "S": [ + "phase_id", + "chemical_formula", + "sg_n", + "entry", + lambda: "crystal structure", + lambda: "angstrom", ], - 'P': [ - 'sample.material.phase_id', - 'sample.material.chemical_formula', - 'sample.material.condition[0].scalar[0].value', - 'sample.material.entry', - 'sample.measurement[0].property.name', - 'sample.measurement[0].property.units', - 'sample.measurement[0].property.scalar' + "P": [ + "sample.material.phase_id", + "sample.material.chemical_formula", + "sample.material.condition[0].scalar[0].value", + "sample.material.entry", + "sample.measurement[0].property.name", + "sample.measurement[0].property.units", + "sample.measurement[0].property.scalar", ], - 'C': [ + "C": [ lambda: None, - 'title', + "title", lambda: None, - 'entry', - lambda: 'phase diagram', - 'naxes', - 'arity' - ] + "entry", + lambda: "phase diagram", + "naxes", + "arity", + ], } - default_titles = ['Phase', 'Formula', 'SG', 'Entry', 'Property', 'Units', 'Value'] + default_titles = ["Phase", "Formula", "SG", "Entry", "Property", "Units", "Value"] endpoint = "https://api.mpds.io/v0/download/facet" pagesize = 1000 - maxnpages = 120 # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM - maxnphases = 1500 # more phases require additional requests + maxnpages = ( + 120 # one hit may reach 50kB in RAM, consider pagesize*maxnpages*50kB free RAM + ) + maxnphases = 1500 # more phases require additional requests chillouttime = 2 # please, do not use values < 2, because the server may burn out verbose = True debug = False - def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None): + def __init__( + self, api_key=None, endpoint=None, dtype=None, verbose=None, debug=None + ): """ MPDS API consumer constructor @@ -116,7 +125,7 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug= Returns: None """ - self.api_key = api_key if api_key else os.environ['MPDS_KEY'] + self.api_key = api_key if api_key else os.environ["MPDS_KEY"] self.network = httplib2.Http() @@ -126,39 +135,42 @@ def __init__(self, api_key=None, endpoint=None, dtype=None, verbose=None, debug= self.debug = debug or self.debug def _request(self, query, phases=None, page=0, pagesize=None): - - phases = ','.join([str(int(x)) for x in phases]) if phases else '' - - uri = self.endpoint + '?' + urlencode({ - 'q': json.dumps(query), - 'phases': phases, - 'page': page, - 'pagesize': pagesize or self.pagesize, - 'dtype': self.dtype - }) + phases = ",".join([str(int(x)) for x in phases]) if phases else "" + + uri = ( + self.endpoint + + "?" + + urlencode( + { + "q": json.dumps(query), + "phases": phases, + "page": page, + "pagesize": pagesize or self.pagesize, + "dtype": self.dtype, + } + ) + ) if self.debug: - print('curl -XGET -HKey:%s \"%s\"' % (self.api_key, uri)) + print('curl -XGET -HKey:%s "%s"' % (self.api_key, uri)) response, content = self.network.request( - uri=uri, - method='GET', - headers={'Key': self.api_key} + uri=uri, method="GET", headers={"Key": self.api_key} ) if response.status != 200: - return {'error': content, 'code': response.status} + return {"error": content, "code": response.status} try: content = json.loads(content) except: - return {'error': 'Unreadable data obtained'} + return {"error": "Unreadable data obtained"} - if content.get('error'): - return {'error': content['error']} + if content.get("error"): + return {"error": content["error"]} - if not content['out']: - return {'error': 'No hits', 'code': 204} + if not content["out"]: + return {"error": "No hits", "code": 204} return content @@ -171,8 +183,8 @@ def _massage(self, array, fields): for item in array: filtered = [] - for object_type in ['S', 'P', 'C']: - if item['object_type'] == object_type: + for object_type in ["S", "P", "C"]: + if item["object_type"] == object_type: for expr in fields.get(object_type, []): if isinstance(expr, jmespath.parser.ParsedResult): filtered.append(expr.search(item)) @@ -201,16 +213,16 @@ def count_data(self, search, phases=None, **kwargs): """ result = self._request(search, phases=phases, pagesize=10) - if result['error']: - raise APIError(result['error'], result.get('code', 0)) + if result["error"]: + raise APIError(result["error"], result.get("code", 0)) - if result['npages'] > self.maxnpages: + if result["npages"] > self.maxnpages: warnings.warn( - "\r\nDataset is too big, you may risk to change maxnpages from %s to %s" % \ - (self.maxnpages, int(math.ceil(result['count']/self.pagesize))) + "\r\nDataset is too big, you may risk to change maxnpages from %s to %s" + % (self.maxnpages, int(math.ceil(result["count"] / self.pagesize))) ) - return result['count'] + return result["count"] def get_data(self, search, phases=None, fields=default_fields): """ @@ -232,56 +244,68 @@ def get_data(self, search, phases=None, fields=default_fields): documented at https://developer.mpds.io/#JSON-schemata """ output = [] - fields = { - key: [jmespath.compile(item) if isinstance(item, str) else item() for item in value] - for key, value in fields.items() - } if fields else None + fields = ( + { + key: [ + jmespath.compile(item) if isinstance(item, str) else item() + for item in value + ] + for key, value in fields.items() + } + if fields + else None + ) tot_count = 0 phases = list(set(phases)) if phases else [] if len(phases) > self.maxnphases: - all_phases = array_split(phases, int(math.ceil( - len(phases)/self.maxnphases - ))) - else: all_phases = [phases] + all_phases = array_split( + phases, int(math.ceil(len(phases) / self.maxnphases)) + ) + else: + all_phases = [phases] nsteps = len(all_phases) for step, current_phases in enumerate(all_phases, start=1): - counter, hits_count = 0, 0 while True: - result = self._request(search, phases=list(current_phases), page=counter) - if result['error']: - raise APIError(result['error'], result.get('code', 0)) + result = self._request( + search, phases=list(current_phases), page=counter + ) + if result["error"]: + raise APIError(result["error"], result.get("code", 0)) - if result['npages'] > self.maxnpages: + if result["npages"] > self.maxnpages: raise APIError( - "Too many hits (%s > %s), please, be more specific" % \ - (result['count'], self.maxnpages * self.pagesize), - 2 + "Too many hits (%s > %s), please, be more specific" + % (result["count"], self.maxnpages * self.pagesize), + 2, ) - output.extend(self._massage(result['out'], fields)) + output.extend(self._massage(result["out"], fields)) - if hits_count and hits_count != result['count']: - raise APIError("API error: hits count has been changed during the query") + if hits_count and hits_count != result["count"]: + raise APIError( + "API error: hits count has been changed during the query" + ) - hits_count = result['count'] + hits_count = result["count"] time.sleep(self.chillouttime) - if counter == result['npages'] - 1: + if counter == result["npages"] - 1: break counter += 1 if self.verbose: - sys.stdout.write("\r\t%d%% of step %s from %s" % ( - (counter/result['npages']) * 100, step, nsteps) - ) + sys.stdout.write( + "\r\t%d%% of step %s from %s" + % ((counter / result["npages"]) * 100, step, nsteps) + ) sys.stdout.flush() tot_count += hits_count @@ -311,24 +335,24 @@ def get_dataframe(self, *args, **kwargs): Returns: (object) Polars dataframe object containing the results """ - columns = kwargs.get('columns') + columns = kwargs.get("columns") if columns: - del kwargs['columns'] + del kwargs["columns"] else: columns = self.default_titles data = self.get_data(*args, **kwargs) return pl.DataFrame(data, schema=columns) - def get_crystals(self, search, phases=None, flavor='pmg', **kwargs): + def get_crystals(self, search, phases=None, flavor="pmg", **kwargs): search["props"] = "atomic structure" crystals = [] for crystal_struct in self.get_data( - search, - phases, - fields={'S':['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']}, - **kwargs + search, + phases, + fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]}, + **kwargs, ): crobj = self.compile_crystal(crystal_struct, flavor) if crobj is not None: @@ -337,7 +361,7 @@ def get_crystals(self, search, phases=None, flavor='pmg', **kwargs): return crystals @staticmethod - def compile_crystal(datarow, flavor='pmg'): + def compile_crystal(datarow, flavor="pmg"): """ Helper method for representing the MPDS crystal structures in two flavors: either as a Pymatgen Structure object, or as an ASE Atoms object. @@ -376,20 +400,22 @@ def compile_crystal(datarow, flavor='pmg'): if len(datarow) < 4: raise ValueError( "Must supply a data row that ends with the entries " - "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'") + "'cell_abc', 'sg_n', 'basis_noneq', 'els_noneq'" + ) - cell_abc, sg_n, basis_noneq, els_noneq = \ - datarow[-4], int(datarow[-3]), datarow[-2], datarow[-1] + cell_abc, sg_n, basis_noneq, els_noneq = ( + datarow[-4], + int(datarow[-3]), + datarow[-2], + datarow[-1], + ) - if flavor == 'pmg' and use_pmg: + if flavor == "pmg" and use_pmg: return Structure.from_spacegroup( - sg_n, - Lattice.from_parameters(*cell_abc), - els_noneq, - basis_noneq + sg_n, Lattice.from_parameters(*cell_abc), els_noneq, basis_noneq ) - elif flavor == 'ase' and use_ase: + elif flavor == "ase" and use_ase: atom_data = [] for num, i in enumerate(basis_noneq): @@ -400,8 +426,8 @@ def compile_crystal(datarow, flavor='pmg'): spacegroup=sg_n, cellpar=cell_abc, primitive_cell=True, - onduplicates='replace' + onduplicates="replace", ) - else: raise APIError("Crystal structure treatment unavailable") - + else: + raise APIError("Crystal structure treatment unavailable") diff --git a/mpds_client/test_export_MPDS.py b/mpds_client/test_export_MPDS.py index 901087e..fbe8d94 100644 --- a/mpds_client/test_export_MPDS.py +++ b/mpds_client/test_export_MPDS.py @@ -1,58 +1,55 @@ import unittest import os import polars as pl -from export_MPDS import MPDSExport +from export_MPDS import MPDSExport class TestMPDSExport(unittest.TestCase): def test_save_plot_csv(self): """Test saving a plot in CSV format.""" - data = pl.DataFrame({ - "length": [1.2, 1.5, 1.8, 2.0, 2.2], - "occurrence": [10, 15, 8, 20, 12] - }) + data = pl.DataFrame( + {"length": [1.2, 1.5, 1.8, 2.0, 2.2], "occurrence": [10, 15, 8, 20, 12]} + ) columns = ["length", "occurrence"] plottype = "bar" - exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='csv') + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="csv") self.assertTrue(os.path.isfile(exported_file)) self.assertTrue(exported_file.endswith(".csv")) def test_save_plot_json(self): """Test saving a plot in JSON format.""" - data = pl.DataFrame({ - "length": [1.2, 1.5, 1.8, 2.0, 2.2], - "occurrence": [10, 15, 8, 20, 12] - }) + data = pl.DataFrame( + {"length": [1.2, 1.5, 1.8, 2.0, 2.2], "occurrence": [10, 15, 8, 20, 12]} + ) columns = ["length", "occurrence"] plottype = "bar" - exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json') + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="json") self.assertTrue(os.path.isfile(exported_file)) self.assertTrue(exported_file.endswith(".json")) def test_save_plot_3d_json(self): """Test saving a 3D plot in JSON format.""" - data = pl.DataFrame({ - "x": [1, 2, 3, 4], - "y": [5, 6, 7, 8], - "z": [9, 10, 11, 12], - "labels": ["A", "B", "C", "D"], - "meshes_id": [1, 1, 2, 2] - }) + data = pl.DataFrame( + { + "x": [1, 2, 3, 4], + "y": [5, 6, 7, 8], + "z": [9, 10, 11, 12], + "labels": ["A", "B", "C", "D"], + "meshes_id": [1, 1, 2, 2], + } + ) columns = ["x", "y", "z", "labels", "meshes_id"] plottype = "plot3d" - exported_file = MPDSExport.save_plot(data, columns, plottype, fmt='json') + exported_file = MPDSExport.save_plot(data, columns, plottype, fmt="json") self.assertTrue(os.path.isfile(exported_file)) self.assertTrue(exported_file.endswith(".json")) def test_save_df(self): """Test saving Polars DataFrame.""" - data = pl.DataFrame({ - "column1": [1, 2, 3], - "column2": [4, 5, 6] - }) + data = pl.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]}) tag = "test" exported_file = MPDSExport.save_df(data, tag) diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py index 3cecae7..898354b 100755 --- a/mpds_client/test_retrieve_MPDS.py +++ b/mpds_client/test_retrieve_MPDS.py @@ -1,5 +1,5 @@ import unittest -#import warnings +# import warnings import polars as pl @@ -15,22 +15,23 @@ class MPDSDataRetrievalTest(unittest.TestCase): @classmethod def setUpClass(cls): - #warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") + # warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") network = httplib2.Http() - response, content = network.request('https://developer.mpds.io/mpds.schema.json') + response, content = network.request( + "https://developer.mpds.io/mpds.schema.json" + ) assert response.status == 200 cls.schema = json.loads(content) Draft4Validator.check_schema(cls.schema) def test_valid_answer(self): - query = { "elements": "K-Ag", "classes": "iodide", "props": "heat capacity", - "lattices": "cubic" + "lattices": "cubic", } client = MPDSDataRetrieval() @@ -40,30 +41,28 @@ def test_valid_answer(self): validate(answer, self.schema) except ValidationError as e: self.fail( - "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" % ( - e.instance, e.context - ) + "The item: \r\n\r\n %s \r\n\r\n has an issue: \r\n\r\n %s" + % (e.instance, e.context) ) def test_crystal_structure(self): - query = { "elements": "Ti-O", "classes": "binary", "props": "atomic structure", - "sgs": 136 + "sgs": 136, } client = MPDSDataRetrieval() ntot = client.count_data(query) self.assertTrue(150 < ntot < 175) - for crystal_struct in client.get_data(query, fields={ - 'S': ['cell_abc', 'sg_n', 'basis_noneq', 'els_noneq']}): - + for crystal_struct in client.get_data( + query, fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]} + ): self.assertEqual(crystal_struct[1], 136) - ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, 'ase') + ase_obj = MPDSDataRetrieval.compile_crystal(crystal_struct, "ase") if ase_obj: self.assertEqual(len(ase_obj), 6) @@ -72,20 +71,22 @@ def test_get_crystals(self): "elements": "Ti-O", "classes": "binary", "props": "atomic structure", - "sgs": 136 + "sgs": 136, } client = MPDSDataRetrieval() ntot = client.count_data(query) logging.debug(f"Value of ntot: {ntot}") self.assertTrue(150 < ntot < 175) - crystals = client.get_crystals(query, flavor='ase') + crystals = client.get_crystals(query, flavor="ase") for crystal in crystals: self.assertIsNotNone(crystal) # now try getting the crystal from the phase_id(s) - phase_ids = {_[0] for _ in client.get_data(query, fields={'S': ['phase_id']})} - crystals_from_phase_ids = client.get_crystals(query, phases=phase_ids, flavor='ase') + phase_ids = {_[0] for _ in client.get_data(query, fields={"S": ["phase_id"]})} + crystals_from_phase_ids = client.get_crystals( + query, phases=phase_ids, flavor="ase" + ) self.assertEqual(len(crystals), len(crystals_from_phase_ids)) @@ -95,15 +96,11 @@ def test_retrieval_of_phases(self): in two ways: maxnphases = changed and maxnphases = default """ - query_a = { - "elements": "O", - "classes": "binary", - "props": "band gap" - } + query_a = {"elements": "O", "classes": "binary", "props": "band gap"} query_b = { "elements": "O", "classes": "binary", - "props": "isothermal bulk modulus" + "props": "isothermal bulk modulus", } client_one = MPDSDataRetrieval() @@ -111,24 +108,28 @@ def test_retrieval_of_phases(self): answer_one = client_one.get_dataframe( query_a, - fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, - columns=['Phid', 'Object'] + fields={ + "P": ["sample.material.phase_id", "sample.material.chemical_formula"] + }, + columns=["Phid", "Object"], ) - if not(isinstance(answer_one, pl.DataFrame)): + if not (isinstance(answer_one, pl.DataFrame)): print(type(answer_one)) raise ValueError("answer_one is not a Polars DataFrame", type(answer_one)) - answer_one = answer_one.filter(pl.col('Phid').is_not_null()) - answer_one = answer_one.with_columns(pl.col('Phid').cast(pl.Int32)) - phases_one = answer_one['Phid'].to_list() - + answer_one = answer_one.filter(pl.col("Phid").is_not_null()) + answer_one = answer_one.with_columns(pl.col("Phid").cast(pl.Int32)) + phases_one = answer_one["Phid"].to_list() + self.assertTrue(len(phases_one) > client_one.maxnphases) result_one = client_one.get_dataframe( query_b, - fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, - columns=['Phid', 'Object'], - phases=phases_one + fields={ + "P": ["sample.material.phase_id", "sample.material.chemical_formula"] + }, + columns=["Phid", "Object"], + phases=phases_one, ) client_two = MPDSDataRetrieval() @@ -136,23 +137,29 @@ def test_retrieval_of_phases(self): answer_two = client_two.get_dataframe( query_a, - fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, - columns=['Phid', 'Object'] + fields={ + "P": ["sample.material.phase_id", "sample.material.chemical_formula"] + }, + columns=["Phid", "Object"], ) - if not(isinstance(answer_one, pl.DataFrame)): + if not (isinstance(answer_one, pl.DataFrame)): print(type(answer_two)) - raise ValueError("answer_one is not a Polars DataFrame, is", type(answer_two)) - - answer_two = answer_two.filter(pl.col('Phid').is_not_null()) - phases_two = answer_two['Phid'].cast(pl.Int32).to_list() + raise ValueError( + "answer_one is not a Polars DataFrame, is", type(answer_two) + ) + + answer_two = answer_two.filter(pl.col("Phid").is_not_null()) + phases_two = answer_two["Phid"].cast(pl.Int32).to_list() self.assertTrue(len(phases_two) < client_two.maxnphases) result_two = client_two.get_dataframe( query_b, - fields={'P': ['sample.material.phase_id', 'sample.material.chemical_formula']}, - columns=['Phid', 'Object'], - phases=phases_two + fields={ + "P": ["sample.material.phase_id", "sample.material.chemical_formula"] + }, + columns=["Phid", "Object"], + phases=phases_two, ) self.assertEqual(len(result_one), len(result_two)) @@ -160,11 +167,14 @@ def test_retrieval_of_phases(self): # check equality of result_one and result_two merge = pl.concat([result_one, result_two]) merge = merge.with_columns(pl.Series("index", range(len(merge)))) - merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg(pl.len()) + merge_gpby = merge.group_by(list(merge.columns), maintain_order=True).agg( + pl.len() + ) idx = [x[0] for x in merge_gpby.iter_rows() if x[-1] == 1] self.assertTrue(merge.filter(pl.col("index").is_in(idx)).is_empty()) -if __name__ == "__main__": + +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From ab209433e835c4f64548db2d3a12133b33edde26 Mon Sep 17 00:00:00 2001 From: alinzh Date: Thu, 19 Dec 2024 18:01:10 +0000 Subject: [PATCH 4/7] Change number of results for test --- mpds_client/test_retrieve_MPDS.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py index 898354b..9319a83 100755 --- a/mpds_client/test_retrieve_MPDS.py +++ b/mpds_client/test_retrieve_MPDS.py @@ -55,7 +55,7 @@ def test_crystal_structure(self): client = MPDSDataRetrieval() ntot = client.count_data(query) - self.assertTrue(150 < ntot < 175) + self.assertTrue(90 < ntot < 200) for crystal_struct in client.get_data( query, fields={"S": ["cell_abc", "sg_n", "basis_noneq", "els_noneq"]} @@ -76,7 +76,7 @@ def test_get_crystals(self): client = MPDSDataRetrieval() ntot = client.count_data(query) logging.debug(f"Value of ntot: {ntot}") - self.assertTrue(150 < ntot < 175) + self.assertTrue(190 < ntot < 200) crystals = client.get_crystals(query, flavor="ase") for crystal in crystals: From 2e0e1faace112351faa1019e8bea567c6b82ee72 Mon Sep 17 00:00:00 2001 From: alinzh Date: Sun, 17 Aug 2025 13:20:07 +0000 Subject: [PATCH 5/7] Add download method for ab initio loggs --- mpds_client/retrieve_MPDS.py | 170 +++++++++++++++++++++++++++++++++-- 1 file changed, 164 insertions(+), 6 deletions(-) diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py index 5cdf34c..7c0e68d 100755 --- a/mpds_client/retrieve_MPDS.py +++ b/mpds_client/retrieve_MPDS.py @@ -1,23 +1,27 @@ +import logging +import math import os +import shutil import sys +import tempfile import time -import math import warnings +from pathlib import Path from urllib.parse import urlencode import httplib2 -import ujson as json -import polars as pl -from numpy import array_split import jmespath - +import polars as pl +import requests +import ujson as json from errors import APIError +from numpy import array_split use_pmg, use_ase = False, False try: - from pymatgen.core.structure import Structure from pymatgen.core.lattice import Lattice + from pymatgen.core.structure import Structure use_pmg = True except ImportError: @@ -431,3 +435,157 @@ def compile_crystal(datarow, flavor="pmg"): else: raise APIError("Crystal structure treatment unavailable") + + def download_ab_initio_logs( + self, + search: dict, + save_dir: Path, + keep_archives: bool = False, + timeout: int = 30, + ): + """ + Download ab initio simulation logs (CRYSTAL .out and Fleur .xml) for materials matching the search criteria. + + Args: + search (dict): Search query like {"props": "electrical conductivity"} + save_dir (str|Path): Directory to save downloaded logs + keep_archives (bool): Whether to keep downloaded archive files + timeout (int): Timeout for download requests in seconds + + Returns: + list: Paths to downloaded log files + """ + try: + from dft_organizer.re_archiver import extract_7z + except ImportError: + raise ImportError( + "dft_organizer package is required for ab initio logs download." + ) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + archive_dir = save_dir / "temp_archives" + archive_dir.mkdir(exist_ok=True) + + # aetup logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(save_dir / "ab_initio_downloader.log"), + ], + ) + logger = logging.getLogger("MPDSDataRetrieval") + + # get URLs + fields = { + "P": [ + "sample.material.entry", + "sample.material.phase_id", + "sample.measurement[0].raw_data", + ] + } + data = self.get_data(search, fields=fields) + + if not data: + logger.warning("No ab initio data found matching the search criteria") + return [] + + saved_files = [] + for item in data: + material_id = item[0] + phase_id = item[1] + archive_url = item[2] + + if not archive_url: + logger.warning(f"No archive URL for material {material_id}") + continue + + logger.info(f"Processing material {material_id}") + + try: + # download archive + response = requests.get(archive_url, timeout=timeout) + response.raise_for_status() + + # save archive + archive_path = archive_dir / f"material_{material_id}.7z" + with open(archive_path, "wb") as f: + f.write(response.content) + logger.info(f"Saved archive: {archive_path}") + + # unpack + material_files = self._extract_logs( + archive_path, material_id, phase_id, save_dir, extract_7z, logger + ) + saved_files.extend(material_files) + logger.info( + f"Extracted {len(material_files)} logs for material {material_id}" + ) + + # delete archive if not keeping archives + if not keep_archives: + archive_path.unlink() + + except Exception as e: + logger.error(f"Error processing material {material_id}: {str(e)}") + + # delete temp archive dir if not keeping archives + if not keep_archives: + shutil.rmtree(archive_dir, ignore_errors=True) + + logger.info(f"Downloaded {len(saved_files)} log files in total") + return saved_files + + def _extract_logs( + self, archive_path: Path, material_id: str, phase_id: str, save_dir: Path, extract_func, logger + ): + """Extract engines logs by extract_7z""" + material_dir = save_dir / f"material_{material_id}" + material_dir.mkdir(exist_ok=True) + saved_files = [] + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # unpack archive + success = extract_func(archive_path, tmp_path) + if not success: + logger.error(f"Failed to extract archive: {archive_path.name}") + return saved_files + + logger.debug(f"Extracted archive: {archive_path.name} to {tmp_path}") + + # try to find and save log files + for file_path in tmp_path.rglob("*"): + if not file_path.is_file(): + continue + + # check if the file is a log file + if ( + file_path.suffix in (".out", ".xml") + or file_path.name == "SIGMA.DAT" + or "TRANSPORT" in file_path.parts + ): + + # create new name with phase info + new_name = f"phase_{phase_id}_{file_path.name}" + dest_path = material_dir / new_name + + shutil.copy2(file_path, dest_path) + saved_files.append(dest_path) + logger.info(f"Saved log: {dest_path}") + + return saved_files + + +if __name__ == "__main__": + client = MPDSDataRetrieval(dtype=MPDSDataTypes.AB_INITIO) + downloaded_files = client.download_ab_initio_logs( + search={"props": "electrical conductivity"}, + save_dir="./ab_initio_logs", + keep_archives=False, + ) + + print(f"Downloaded {len(downloaded_files)} log files") From 76e847098fbaab913bda69ca92e0ca8f047b66f2 Mon Sep 17 00:00:00 2001 From: alinzh Date: Sun, 17 Aug 2025 16:25:16 +0000 Subject: [PATCH 6/7] Add unittest --- mpds_client/test_retrieve_MPDS.py | 70 ++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/mpds_client/test_retrieve_MPDS.py b/mpds_client/test_retrieve_MPDS.py index 9319a83..ec58959 100755 --- a/mpds_client/test_retrieve_MPDS.py +++ b/mpds_client/test_retrieve_MPDS.py @@ -1,15 +1,17 @@ +import logging +import os +import tempfile import unittest -# import warnings - -import polars as pl +from pathlib import Path import httplib2 +import polars as pl import ujson as json -from jsonschema import validate, Draft4Validator +from jsonschema import Draft4Validator, validate from jsonschema.exceptions import ValidationError +from retrieve_MPDS import MPDSDataRetrieval, MPDSDataTypes -from retrieve_MPDS import MPDSDataRetrieval -import logging +# import warnings class MPDSDataRetrievalTest(unittest.TestCase): @@ -174,6 +176,62 @@ def test_retrieval_of_phases(self): self.assertTrue(merge.filter(pl.col("index").is_in(idx)).is_empty()) + def test_download_ab_initio_logs_real_simple(self): + """ + Simple real download test for ab initio logs + Downloads logs for one material and verifies results + """ + with tempfile.TemporaryDirectory() as tmp_dir: + save_dir = Path(tmp_dir) + + client = MPDSDataRetrieval(dtype=MPDSDataTypes.AB_INITIO) + query = { + "props": "electrical conductivity", + } + + # execute download with timeout + try: + downloaded_files = client.download_ab_initio_logs( + search=query, save_dir=save_dir, timeout=120 + ) + + if not downloaded_files: + self.skipTest("No data found for the test query") + + # check that some files were downloaded + self.assertGreater(len(downloaded_files), 0, "No files downloaded") + + # check that files exist + for file_path in downloaded_files: + self.assertTrue( + file_path.exists(), f"File {file_path} does not exist" + ) + self.assertGreater( + os.path.getsize(file_path), 100, "File is too small" + ) + + # check directory structure + material_dirs = list(save_dir.glob("material_*")) + self.assertTrue(material_dirs, "No material directories created") + + # check for expected file types + found_out = False + found_dat = False + for file_path in downloaded_files: + if file_path.suffix == ".out": + found_out = True + if file_path.name.endswith("SIGMA.DAT"): + found_dat = True + + self.assertTrue(found_out or found_dat, "No expected log files found") + + logging.info( + f"Successfully downloaded {len(downloaded_files)} log files" + ) + + except Exception as e: + self.fail(f"Download failed: {str(e)}") + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) From 18939a9b929a6b589ce6a60601b5db4bee5a45b4 Mon Sep 17 00:00:00 2001 From: alinzh Date: Thu, 4 Sep 2025 09:01:00 +0000 Subject: [PATCH 7/7] Remove dft_organizer --- mpds_client/retrieve_MPDS.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mpds_client/retrieve_MPDS.py b/mpds_client/retrieve_MPDS.py index 7c0e68d..f718050 100755 --- a/mpds_client/retrieve_MPDS.py +++ b/mpds_client/retrieve_MPDS.py @@ -12,6 +12,7 @@ import httplib2 import jmespath import polars as pl +import py7zr import requests import ujson as json from errors import APIError @@ -435,6 +436,17 @@ def compile_crystal(datarow, flavor="pmg"): else: raise APIError("Crystal structure treatment unavailable") + + @staticmethod + def extract_7z(archive_path: Path, target_dir: Path): + """Unpack 7z archive to target dir""" + try: + with py7zr.SevenZipFile(archive_path, mode='r') as archive: + archive.extractall(target_dir) + return True + except Exception as e: + print(f"Error during unpack {archive_path}: {e}") + return False def download_ab_initio_logs( self, @@ -455,13 +467,6 @@ def download_ab_initio_logs( Returns: list: Paths to downloaded log files """ - try: - from dft_organizer.re_archiver import extract_7z - except ImportError: - raise ImportError( - "dft_organizer package is required for ab initio logs download." - ) - save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) archive_dir = save_dir / "temp_archives" @@ -517,7 +522,7 @@ def download_ab_initio_logs( # unpack material_files = self._extract_logs( - archive_path, material_id, phase_id, save_dir, extract_7z, logger + archive_path, material_id, phase_id, save_dir, self.extract_7z, logger ) saved_files.extend(material_files) logger.info( @@ -541,7 +546,7 @@ def download_ab_initio_logs( def _extract_logs( self, archive_path: Path, material_id: str, phase_id: str, save_dir: Path, extract_func, logger ): - """Extract engines logs by extract_7z""" + """Extract engines logs""" material_dir = save_dir / f"material_{material_id}" material_dir.mkdir(exist_ok=True) saved_files = []