diff --git a/src/reportengine/app.py b/src/reportengine/app.py index 84e0565..714da5e 100644 --- a/src/reportengine/app.py +++ b/src/reportengine/app.py @@ -175,9 +175,12 @@ def argparser(self): help='matplotlib style file to override the built-in one.', default=None) - parser.add_argument('--formats', nargs='+', help="formats of the output figures", + parser.add_argument('--figure-formats', nargs='+', help="formats of the output figures", default=('png', 'pdf',)) + parser.add_argument('--table-formats', nargs='+', default=('csv',), choices=["parquet", "csv"], + help="Format to save tables as. Note csv is the only human readable format.") + parser.add_argument('-x', '--extra-providers', nargs='+', help="additional providers from which to " "load actions. Must be an importable specifiaction.") @@ -281,6 +284,16 @@ def init(self, cmdline=None): import faulthandler faulthandler.enable() args = self.get_commandline_arguments(cmdline) + if 'parquet' in args['table_formats']: + try: + import pyarrow + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Failed to import module pyarrow. " + "This is a required dependency to save " + "tables in the parquet format. " + "Please run conda install pyarrow and try again." + ) self.init_logging(args) sys.excepthook = self.excepthook try: diff --git a/src/reportengine/environment.py b/src/reportengine/environment.py index 4e567a8..004dd1a 100644 --- a/src/reportengine/environment.py +++ b/src/reportengine/environment.py @@ -30,7 +30,8 @@ class EnvironmentError_(Exception): pass } class Environment: - def __init__(self, *, output=None, formats=('pdf',), + def __init__(self, *, output=None, + figure_formats=('pdf',), table_formats=('csv',), default_figure_format=None, loglevel=logging.DEBUG, config_yml = None, **kwargs): @@ -38,7 +39,9 @@ def __init__(self, *, output=None, formats=('pdf',), self.output_path = pathlib.Path(output).absolute() else: self.output_path = output - self.figure_formats = formats + + self.table_formats = table_formats + self.figure_formats = figure_formats self._default_figure_format = default_figure_format self.loglevel = loglevel self.extra_args = kwargs @@ -111,6 +114,10 @@ def get_figure_paths(self, handle): for fmt in self.figure_formats: yield self.figure_folder / (handle + '.' + fmt) + def get_table_paths(self, handle): + for fmt in self.table_formats: + yield self.table_folder / (handle + '.' + fmt) + @classmethod def ns_dump_description(cls): return dict( diff --git a/src/reportengine/table.py b/src/reportengine/table.py index c5ec6b8..18eb3a1 100644 --- a/src/reportengine/table.py +++ b/src/reportengine/table.py @@ -57,17 +57,46 @@ def as_markdown(self): res = re.sub('\n\s+', '\n', res) return res - - -def prepare_path(*, spec, namespace,environment, **kwargs): - name = spec_to_nice_name(namespace, spec) - path = environment.table_folder / (name + '.csv') - return {'path': path} - -def savetable(df, path): +def str_columns(df): + log.debug("Changing column types to str") + cols = df.columns + if isinstance(cols, pd.MultiIndex): + for i in range(cols.nlevels): + str_col = cols.levels[i].astype(str) + # Could use inplace but it's + # going to bedeprecated + cols = cols.set_levels(str_col, i) + else: + cols = cols.astype(str) + df.columns = cols + return df + +def prepare_path(*, spec, namespace, environment, **kwargs): + paths = environment.get_table_paths(spec_to_nice_name(namespace, spec)) + return {'paths': list(paths)} + +def savetable(df, paths): """Final action to save figures, with a nice filename""" - log.debug("Writing table %s" % path) - df.to_csv(str(path), sep='\t', na_rep='nan') + for path in paths: + log.debug("Writing table %s" % path) + format = path.suffix[1:] + if format == "parquet": # Default to parquet format + try: + df.to_parquet(str(path)) + except ValueError as e: + # Need to change the type of each level to str + raise ValueError( + "To save a table in parquet format the column entries must all be of type str. " + "Consider using the helper function reportengine.table.str_columns before passing the " + "dataframe to the savetable function." + ) from e + elif format == "csv": + df.to_csv(str(path), sep='\t', na_rep='nan') + else: + raise NotImplementedError( + f"Unrecognised format {format}", + "choose one of parquet or csv" + ) return Table.fromdf(df, path=path) def savetablelist(dfs, path):