From 1fdd6443528435a713060b2c7bd79f340e3a5985 Mon Sep 17 00:00:00 2001 From: martinkilbinger Date: Wed, 12 Nov 2025 11:55:28 +0100 Subject: [PATCH 1/3] added script to plot r-band magnitude histograms for different selections; cleaned up plotting functions --- .../cosmo_val/catalog_paper_plot/hist_mag.py | 384 ++++++++++++++++++ src/sp_validation/__init__.py | 22 +- src/sp_validation/basic.py | 2 - src/sp_validation/plots.py | 83 ++++ src/sp_validation/run_joint_cat.py | 102 +---- 5 files changed, 490 insertions(+), 103 deletions(-) create mode 100644 notebooks/cosmo_val/catalog_paper_plot/hist_mag.py diff --git a/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py new file mode 100644 index 0000000..3ee2c88 --- /dev/null +++ b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py @@ -0,0 +1,384 @@ +# %% +# hist_mag.py +# +# Plot magnitude histogram for various cuts and selection criteria + +# %% +import matplotlib +import matplotlib.pylab as plt + +# enable autoreload for interactive sessions +from IPython import get_ipython +ipython = get_ipython() +if ipython is not None: + ipython.run_line_magic("matplotlib", "inline") + ipython.run_line_magic("reload_ext", "autoreload") + ipython.run_line_magic("autoreload", "2") + ipython.run_line_magic("reload_ext", "log_cell_time") + +# %% +import sys +import os +import re +import numpy as np +from astropy.io import fits +from io import StringIO + +from sp_validation import run_joint_cat as sp_joint +from sp_validation import util +from sp_validation.basic import metacal +from sp_validation import calibration +import sp_validation.cat as cat + +# %% +# Initialize calibration class instance +obj = sp_joint.CalibrateCat() + +config = obj.read_config_set_params("config_mask.yaml") + +hist_data_path = "magnitude_histograms_data.npz" + +test_only = False + +# %% +def get_data(obj, test_only=False): + + + # Get data. Set load_into_memory to False for very large files + dat, dat_ext = obj.read_cat(load_into_memory=False) + + if test_only: + n_max = 1_000_000 + print(f"MKDEBUG testing only first {n_max} objects") + dat = dat[:n_max] + dat_ext = dat_ext[:n_max] + + return dat, dat_ext + + +def read_hist_data(hist_data_path): + """ + Read histogram data from npz file. + + Parameters + ---------- + hist_data_path : str + Path to the npz file containing histogram data + + Returns + ------- + hist_data : dict + Dictionary with keys for each selection criterion containing: + - 'counts': histogram counts + - 'bins': bin edges + - 'label': label for the histogram + """ + loaded = np.load(hist_data_path, allow_pickle=True) + hist_data = {} + + for key in loaded.files: + data = loaded[key] + hist_data[key] = { + 'counts': data[0], + 'bins': data[1], + 'label': str(data[2]) + } + + return hist_data + + + +def get_mask(masks, col_name): + + # Get mask fomr masks with col_name = col_name + for mask in masks: + if mask._col_name == col_name: + return mask + + +def compute_hist(masks, col_name, mask_cumul, mag, bins): + """ + Compute histogram for given mask and magnitude data. + + Parameters + ---------- + masks : list + List of mask objects + col_name : str + Column name to identify the mask + mask_cumul : array or None + Cumulative mask array + mag : array + Magnitude data + bins : array + Bin edges for histogram + + Returns + ------- + counts : array + Histogram counts + bins : array + Bin edges + label : str + Label for the histogram + n_valid : int + Number of valid data points after masking + mask_cumul : array + Updated cumulative mask array + """ + this_mask = get_mask(masks, col_name) + + # First time: + if mask_cumul is None: + print("Init mask_cumul") + mask_cumul = this_mask._mask + else: + mask_cumul &= this_mask._mask + + # Data values + my_mag = mag[mask_cumul] + + # Data count + n_valid = np.sum(mask_cumul) + + # Label + string_buffer = StringIO() + this_mask.print_condition(string_buffer) + label = string_buffer.getvalue().strip() + if label == "": + label = col_name + + counts, bin_edges = np.histogram(my_mag, bins=bins) + + return counts, bin_edges, label, n_valid, mask_cumul + + +def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): + """ + Plot histogram from counts and bins using bar plot. + + Parameters + ---------- + counts : array + Histogram counts + bins : array + Bin edges + label : str + Label for the histogram + alpha : float + Transparency for plot + ax : matplotlib axis + Axis to plot on + color : str or None + Color for the histogram + + """ + bin_centers = 0.5 * (bins[:-1] + bins[1:]) + ax.bar( + bin_centers, + counts, + width=np.diff(bins), + alpha=alpha, + label=label, + align='center', + color=color + ) + +# %% + +if os.path.exists(hist_data_path): + print(f"Histogram data file {hist_data_path} found.") + print("Reading and plotting.") + + dat = dat_ext = None + hist_data = read_hist_data(hist_data_path) + +else: + print(f"Histogram data file {hist_data_path} not found.") + print("Reading UNIONS cat and computing.") + + dat, dat_ext = get_data(obj, test_only=test_only) + hist_data = None + + +# ## Masking +# %% +# Get all masks, with or without dat, dat_ext +masks, labels = sp_joint.get_masks_from_config( + config, + dat, + dat_ext, + verbose=obj._params["verbose"], +) + +# %% +# List of basic masks to apply to all cases +masks_labels_basic = [ + "FLAGS", + "overlap", + "IMAFLAGS_ISO", + "mag", + "NGMIX_MOM_FAIL", + "NGMIX_ELL_PSFo_NOSHEAR_0", + "NGMIX_ELL_PSFo_NOSHEAR_1", + "4_Stars", + "8_Manual", + "64_r", + "1024_Maximask", +] + +# %% +masks_basic = [] +for mask in masks: + if mask._col_name in masks_labels_basic: + masks_basic.append(mask) + +if dat is not None: + print("Creating combined basic mask") + masks_basic_combined = sp_joint.Mask.from_list( + masks_basic, + label="combined", + verbose=obj._params["verbose"], + ) +else: + # Create dummy combined mask (for plot label) + print("Creating dummy combined basic mask") + masks_basic_combined = sp_joint.Mask( + "combined", + "combined", + kind="none", + ) + +masks.append(masks_basic_combined) + +# %% +# Metacal mask (cuts) +mask_tmp = sp_joint.Mask( + "metacal", + "metacal", + kind="none", +) + +# %% +if dat is not None: + cm = config["metacal"] + gal_metacal = metacal( + dat, + masks_basic_combined._mask, + snr_min=cm["gal_snr_min"], + snr_max=cm["gal_snr_max"], + rel_size_min=cm["gal_rel_size_min"], + rel_size_max=cm["gal_rel_size_max"], + size_corr_ell=cm["gal_size_corr_ell"], + sigma_eps=cm["sigma_eps_prior"], + global_R_weight=cm["global_R_weight"], + col_2d=False, + verbose=True, + ) + + g_corr_mc, g_uncorr, w, mask_metacal, c, c_err = ( + calibration.get_calibrated_m_c(gal_metacal) + ) + + # Convert index array to boolean mask + mask_metacal_bool = np.zeros(len(dat), dtype=bool) + mask_metacal_bool[mask_metacal] = True + + mask_tmp._mask = mask_metacal_bool + +masks.append(mask_tmp) + + +# %% +# Plot magnitude histograms for various selection criteria + +# Define magnitude bins +mag_bins = np.arange(15, 30, 0.05) +mag_centers = 0.5 * (mag_bins[:-1] + mag_bins[1:]) + + +# %% +# Create figure with multiple subplots +figsize = 10 +alpha = 0.5 + +col_names = ["combined", "N_EPOCH", "npoint3", "metacal"] + +# Define explicit colors for each histogram +colors = ['C0', 'C1', 'C2', 'C3'] # Use matplotlib default color cycle +color_map = dict(zip(col_names, colors)) + + +# %% +# If hist_data not loaded, compute it +if hist_data is None: + hist_data = {} + +if dat is not None: + # Get magnitude column + mag = dat['mag'] + + mask_cumul = None + for col_name in col_names: + counts, bins, label, n_valid, mask_cumul = compute_hist( + masks=masks, + col_name=col_name, + mask_cumul=mask_cumul, + mag=mag, + bins=mag_bins + ) + hist_data[col_name] = { + 'counts': counts, + 'bins': bins, + 'label': label, + 'n_valid': n_valid, + } + +# Plot histogram data +plt.figure() +fig, (ax) = plt.subplots(1, 1, figsize=(figsize, figsize)) + +for col_name in col_names: + if col_name in hist_data: + data = hist_data[col_name] + plot_hist( + data['counts'], + data['bins'], + data['label'], + alpha=alpha, + ax=ax, + color=color_map[col_name] + ) + #print(f"{col_name}: n_valid = {data['n_valid']}") + +ax.set_xlabel('$r$') +ax.set_ylabel('Number') +ax.set_xlim(17.5, 26.5) +ax.legend() + +plt.tight_layout() +plt.savefig('magnitude_histograms.png', dpi=150, bbox_inches='tight') + +# Save histogram data to file (only if we computed it) +if dat is not None: + np.savez( + hist_data_path, + **{ + key: np.array( + [ + val['counts'], + val['bins'], + val['label'], + val["n_valid"], + ], + dtype=object + ) + for key, val in hist_data.items() + } + ) + print(f"Histogram data saved to {hist_data_path}") + +# %% +if dat is not None: + obj.close_hd5() +# %% diff --git a/src/sp_validation/__init__.py b/src/sp_validation/__init__.py index b31411d..62e2307 100644 --- a/src/sp_validation/__init__.py +++ b/src/sp_validation/__init__.py @@ -16,14 +16,14 @@ ] # Explicit imports to avoid circular issues -from . import util -from . import io -from . import basic -from . import galaxy -from . import cosmology -from . import calibration -from . import cat -from . import plot_style -from . import plots -from . import run_joint_cat -from . import survey \ No newline at end of file +#from . import util +#from . import io +#from . import basic +#from . import galaxy +#from . import cosmology +#from . import calibration +#from . import cat +#from . import plot_style +#from . import plots +#from . import run_joint_cat +#from . import survey diff --git a/src/sp_validation/basic.py b/src/sp_validation/basic.py index 9c0972e..3720004 100644 --- a/src/sp_validation/basic.py +++ b/src/sp_validation/basic.py @@ -15,8 +15,6 @@ from scipy.spatial import cKDTree from scipy.special import gamma -import matplotlib.pyplot as plt - from tqdm import tqdm import operator as op import itertools as itools diff --git a/src/sp_validation/plots.py b/src/sp_validation/plots.py index d9c5d0e..d565048 100644 --- a/src/sp_validation/plots.py +++ b/src/sp_validation/plots.py @@ -669,3 +669,86 @@ def hsp_map_logical_or(maps, verbose=False): ) return map_comb + + +def plot_area_mask(ra, dec, zoom, mask=None): + """Plot Area Mask. + + Create sky plot of objects. + + Parameters + ---------- + ra : list + R.A. coordinates + dec : list + Dec. coordinates + zoom : TBD + mask: TBD, optional + + """ + if mask is None: + mask == np.ones_like(ra) + + fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30,15)) + axes[0].hexbin(ra[mask], dec[mask], gridsize=100) + axes[1].hexbin(ra[mask & zoom], dec[mask & zoom], gridsize=200) + for idx in (0, 1): + axes[idx].set_xlabel("R.A. [deg]") + axes[idx].set_ylabel("Dec [deg]") + + +def sky_plots(dat, masks, labels, zoom_ra, zoom_dec): + """Sky Plots. + + Plot sky regions with different masks. + + Parameters + ---------- + masks : list + masks to be applied + labels : dict + labels for masks + zoom_ra : list + min and max R.A. for zoom-in plot + zoom_dec : list + min and max Dec. for zoom-in plot + + """ + ra = dat["RA"][:] + dec = dat["Dec"][:] + + zoom_ra = (room_ra[0] < dat["RA"]) & (dat["RA"] < zoom_ra[1]) + zoom_dec = (zoom_dec[0] < dat["Dec"]) & (dat["Dec"] < zoom_dec[1]) + zoom = zoom_ra & zoom_dec + + # No mask + plot_area_mask(ra, dec, zoom) + + # SExtractor and SP flags + m_flags = masks[labels["FLAGS"]]._mask & masks[labels["IMAFLAGS_ISO"]]._mask + plot_area_mask(ra, dec, zoom, mask=m_flags) + + # Overlap regions + m_over = masks[labels["overlap"]]._mask & m_flags + plot_area_mask(ra, dec, zoom, mask=m_over) + + # Coverage mask + m_point = masks[labels["npoint3"]]._mask & m_over + plot_area_mask(ra, dec, zoom, mask=m_point) + + # Maximask + m_maxi = masks[labels["1024_Maximask"]]._mask & m_point + plot_area_mask(ra, dec, zoom, mask=m_maxi) + + m_comb = mask_combined._mask + plot_area_mask(ra, dec, zoom, mask=m_comb) + + m_man = m_maxi & masks[labels["8_Manual"]]._mask + plot_area_mask(ra, dec, zoom, mask=m_man) + + m_halos = ( + m_maxi + & masks[labels['1_Faint_star_halos']]._mask + & masks[labels['2_Bright_star_halos']]._mask + ) + plot_area_mask(ra, dec, zoom, mask=m_halos) diff --git a/src/sp_validation/run_joint_cat.py b/src/sp_validation/run_joint_cat.py index 9177eeb..c9d5965 100644 --- a/src/sp_validation/run_joint_cat.py +++ b/src/sp_validation/run_joint_cat.py @@ -1181,89 +1181,6 @@ def run(self): """ -def sky_plots(dat, masks, labels, zoom_ra, zoom_dec): - """Sky Plots. - - Plot sky regions with different masks. - - Parameters - ---------- - masks : list - masks to be applied - labels : dict - labels for masks - zoom_ra : list - min and max R.A. for zoom-in plot - zoom_dec : list - min and max Dec. for zoom-in plot - - """ - ra = dat["RA"][:] - dec = dat["Dec"][:] - - zoom_ra = (room_ra[0] < dat["RA"]) & (dat["RA"] < zoom_ra[1]) - zoom_dec = (zoom_dec[0] < dat["Dec"]) & (dat["Dec"] < zoom_dec[1]) - zoom = zoom_ra & zoom_dec - - # No mask - plot_area_mask(ra, dec, zoom) - - # SExtractor and SP flags - m_flags = masks[labels["FLAGS"]]._mask & masks[labels["IMAFLAGS_ISO"]]._mask - plot_area_mask(ra, dec, zoom, mask=m_flags) - - # Overlap regions - m_over = masks[labels["overlap"]]._mask & m_flags - plot_area_mask(ra, dec, zoom, mask=m_over) - - # Coverage mask - m_point = masks[labels["npoint3"]]._mask & m_over - plot_area_mask(ra, dec, zoom, mask=m_point) - - # Maximask - m_maxi = masks[labels["1024_Maximask"]]._mask & m_point - plot_area_mask(ra, dec, zoom, mask=m_maxi) - - m_comb = mask_combined._mask - plot_area_mask(ra, dec, zoom, mask=m_comb) - - m_man = m_maxi & masks[labels["8_Manual"]]._mask - plot_area_mask(ra, dec, zoom, mask=m_man) - - m_halos = ( - m_maxi - & masks[labels['1_Faint_star_halos']]._mask - & masks[labels['2_Bright_star_halos']]._mask - ) - plot_area_mask(ra, dec, zoom, mask=m_halos) - - -def plot_area_mask(ra, dec, zoom, mask=None): - """Plot Area Mask. - - Create sky plot of objects. - - Parameters - ---------- - ra : list - R.A. coordinates - dec : list - Dec. coordinates - zoom : TBD - mask: TBD, optional - - """ - if mask is None: - mask == np.ones_like(ra) - - fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30,15)) - axes[0].hexbin(ra[mask], dec[mask], gridsize=100) - axes[1].hexbin(ra[mask & zoom], dec[mask & zoom], gridsize=200) - for idx in (0, 1): - axes[idx].set_xlabel("R.A. [deg]") - axes[idx].set_ylabel("Dec [deg]") - - def confusion_matrix(mask, confidence_level=0.9): n_key = len(mask) @@ -1361,7 +1278,7 @@ class Mask(): """ - def __init__(self, col_name, label, kind="equal", value=0, dat=None, verbose=False): + def __init__(self, col_name, label, kind=None, value=0, dat=None, verbose=False): self._col_name = col_name self._label = label @@ -1439,8 +1356,7 @@ def print_strings(cls, coln, lab, num, fnum, f_out=None): print(msg) if f_out: print(msg, file=f_out) - - + def print_stats(self, num_obj, f_out=None): if self._num_ok is None: self._num_ok = sum(self._mask) @@ -1461,10 +1377,12 @@ def get_sign(self): elif self._kind =="smaller_equal": sign = "<=" return sign - - def print_summary(self, f_out): - print(f"[{self._label}]\t\t\t", file=f_out, end="") - + + def print_condition(self, f_out): + + if self._value is None: + return "" + sign = self.get_sign() if sign is not None: @@ -1472,6 +1390,10 @@ def print_summary(self, f_out): if self._kind == "range": print(f"{self._value[0]} <= {self._col_name} <= {self._value[1]}", file=f_out) + + def print_summary(self, f_out): + print(f"[{self._label}]\t\t\t", file=f_out, end="") + self.print_condition(f_out) def create_descr(self): """Create Descr. From 9a39d94a83dbfe5f9fa5259fe5da3d3d0cace56c Mon Sep 17 00:00:00 2001 From: martinkilbinger Date: Fri, 14 Nov 2025 11:29:43 +0100 Subject: [PATCH 2/3] Improved mask labels for plots --- config/calibration/mask_v1.X.6.yaml | 6 +- .../cosmo_val/catalog_paper_plot/hist_mag.py | 283 ++++++++++++++---- src/sp_validation/run_joint_cat.py | 32 +- 3 files changed, 244 insertions(+), 77 deletions(-) diff --git a/config/calibration/mask_v1.X.6.yaml b/config/calibration/mask_v1.X.6.yaml index c3c7b25..9dee194 100644 --- a/config/calibration/mask_v1.X.6.yaml +++ b/config/calibration/mask_v1.X.6.yaml @@ -31,13 +31,13 @@ dat: # Number of epochs - col_name: N_EPOCH - label: r"$n_{\rm epoch}$" + label: $n_{\rm epoch}$ kind: greater_equal value: 2 # Magnitude range - col_name: mag - label: mag range + label: r kind: range value: [15, 30] @@ -86,7 +86,7 @@ dat_ext: # Rough pointing coverage - col_name: npoint3 - label: r"$n_{\rm pointing}$" + label: $n_{\rm pointing}$ kind: greater_equal value: 3 diff --git a/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py index 3ee2c88..4ced303 100644 --- a/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py +++ b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py @@ -36,14 +36,24 @@ config = obj.read_config_set_params("config_mask.yaml") -hist_data_path = "magnitude_histograms_data.npz" - -test_only = False +test_only = True # %% +# Funcitons def get_data(obj, test_only=False): - - + """Get Data. + + Returns catalogue. + + Parameters + ---------- + obj : CalibrateCat instance + Instance of CalibrateCat class + test_only : bool, optional + If True, only load a subset of data for testing; + default is False. + + """ # Get data. Set load_into_memory to False for very large files dat, dat_ext = obj.read_cat(load_into_memory=False) @@ -58,6 +68,8 @@ def get_data(obj, test_only=False): def read_hist_data(hist_data_path): """ + Read Hist Data. + Read histogram data from npz file. Parameters @@ -72,6 +84,7 @@ def read_hist_data(hist_data_path): - 'counts': histogram counts - 'bins': bin edges - 'label': label for the histogram + """ loaded = np.load(hist_data_path, allow_pickle=True) hist_data = {} @@ -87,13 +100,29 @@ def read_hist_data(hist_data_path): return hist_data - def get_mask(masks, col_name): - + """Get Mask. + + Returns mask corresponding to col_name. + + Parameters + ---------- + masks : list + List of mask objects + col_name : str + Column name to identify the mask + Returns + ------- + list + Mask object + integer + Mask position in list + + """ # Get mask fomr masks with col_name = col_name - for mask in masks: + for idx, mask in enumerate(masks): if mask._col_name == col_name: - return mask + return mask, idx def compute_hist(masks, col_name, mask_cumul, mag, bins): @@ -125,12 +154,12 @@ def compute_hist(masks, col_name, mask_cumul, mag, bins): Number of valid data points after masking mask_cumul : array Updated cumulative mask array + """ - this_mask = get_mask(masks, col_name) + this_mask, _ = get_mask(masks, col_name) # First time: if mask_cumul is None: - print("Init mask_cumul") mask_cumul = this_mask._mask else: mask_cumul &= this_mask._mask @@ -143,14 +172,15 @@ def compute_hist(masks, col_name, mask_cumul, mag, bins): # Label string_buffer = StringIO() - this_mask.print_condition(string_buffer) + this_mask.print_condition(string_buffer, latex=True) label = string_buffer.getvalue().strip() if label == "": label = col_name + print("MKDEBUG", label) counts, bin_edges = np.histogram(my_mag, bins=bins) - return counts, bin_edges, label, n_valid, mask_cumul + return counts, bin_edges, rf"{label}", n_valid, mask_cumul def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): @@ -184,7 +214,68 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): color=color ) + +def plot_all_hists( + hist_data, + col_names, + figsize=10, + alpha=0.5, + color_map=None, + fraction=False, + ax=None, + out_path=None, +): + + if ax is None: + plt.figure() + fig, (ax) = plt.subplots( + 1, + 1, + figsize=(figsize, figsize) + ) + + counts0 = None + for col_name in col_names: + if col_name in hist_data: + + data = hist_data[col_name] + if fraction: + if counts0 is None: + counts0 = data["counts"] + counts = data["counts"] / counts0 + else: + counts = data["counts"] + + plot_hist( + counts, + data['bins'], + data['label'], + alpha=alpha, + ax=ax, + color=color_map[col_name] + ) + #print(f"{col_name}: n_valid = {data['n_valid']}") + + ax.set_xlabel('$r$') + ylabel = "fraction" if fraction else "number" + ax.set_ylabel(ylabel) + ax.set_xlim(17.5, 26.5) + if not fraction: + ax.legend() + + if out_path: + plt.tight_layout() + plt.savefig( + out_path, + dpi=150, + bbox_inches='tight' + ) + + # %% +# Main program +scenario = 1 +hist_data_path = f"magnitude_histograms_data_scenario-{scenario}.npz" if os.path.exists(hist_data_path): print(f"Histogram data file {hist_data_path} found.") @@ -201,8 +292,8 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): hist_data = None -# ## Masking # %% +# Masking # Get all masks, with or without dat, dat_ext masks, labels = sp_joint.get_masks_from_config( config, @@ -212,22 +303,89 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): ) # %% +# Combine mask according to scenario # List of basic masks to apply to all cases -masks_labels_basic = [ - "FLAGS", - "overlap", - "IMAFLAGS_ISO", - "mag", - "NGMIX_MOM_FAIL", - "NGMIX_ELL_PSFo_NOSHEAR_0", - "NGMIX_ELL_PSFo_NOSHEAR_1", - "4_Stars", - "8_Manual", - "64_r", - "1024_Maximask", -] + +masks_labels_basic = ["overlap", "mag", "64_r"] +col_names = ["basic masks"] + +if scenario == 0: + masks_labels_basic.extend([ + "FLAGS", + "IMAFLAGS_ISO", + "NGMIX_MOM_FAIL", + "NGMIX_ELL_PSFo_NOSHEAR_0", + "NGMIX_ELL_PSFo_NOSHEAR_1", + "4_Stars", + "8_Manual", + "1024_Maximask", + ]) + + col_names.extend(["N_EPOCH", "npoint3", "metacal"]) + +elif scenario == 1: + + col_names.extend([ + "IMAFLAGS_ISO", + "FLAGS", + "NGMIX_MOM_FAIL", + "NGMIX_ELL_PSFo_NOSHEAR_0", + "NGMIX_ELL_PSFo_NOSHEAR_1", + "4_Stars", + "8_Manual", + "1024_Maximask", + "N_EPOCH", + "npoint3", + "metacal", + ]) + + combine_cols = { + "ngmix failures": [ + "NGMIX_MOM_FAIL", + "NGMIX_ELL_PSFo_NOSHEAR_0", + "NGMIX_ELL_PSFo_NOSHEAR_1", + ] + } + +# %% +# Combine columns if specified. +# Remove old columns after combining. +if combine_cols is not None: + for new_col, old_cols in combine_cols.items(): + print(f"Combining columns {old_cols} into {new_col}") + # Create combined mask + old_masks = [] + idx_first = None + for col in old_cols: + mask, idx = get_mask(masks, col) + old_masks.append(mask) + if idx_first is None: + idx_first = idx + if dat is not None: + print(f"Creating combined mask for {new_col}") + masks_combined = sp_joint.Mask.from_list( + old_masks, + label=new_col, + verbose=obj._params["verbose"], + ) + else: + print(f"Creating dummy mask for {new_col} (for plot label)") + masks_combined = sp_joint.Mask( + new_col, + new_col, + kind="none", + ) + masks.insert(idx, masks_combined) + col_names.insert(idx, new_col) + + for old_mask, old_col in zip(old_masks, old_cols): + masks.remove(old_mask) + col_names.remove(old_col) + + print("After combining: masks =", [mask._col_name for mask in masks]) # %% +# Createe list of masks masks_basic = [] for mask in masks: if mask._col_name in masks_labels_basic: @@ -237,15 +395,15 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): print("Creating combined basic mask") masks_basic_combined = sp_joint.Mask.from_list( masks_basic, - label="combined", + label="basic masks", verbose=obj._params["verbose"], ) else: # Create dummy combined mask (for plot label) print("Creating dummy combined basic mask") masks_basic_combined = sp_joint.Mask( - "combined", - "combined", + "basic masks", + "basic masks", kind="none", ) @@ -260,6 +418,7 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): ) # %% +# Call metacal if data is available if dat is not None: cm = config["metacal"] gal_metacal = metacal( @@ -290,22 +449,17 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): # %% -# Plot magnitude histograms for various selection criteria - # Define magnitude bins mag_bins = np.arange(15, 30, 0.05) mag_centers = 0.5 * (mag_bins[:-1] + mag_bins[1:]) - # %% # Create figure with multiple subplots figsize = 10 alpha = 0.5 -col_names = ["combined", "N_EPOCH", "npoint3", "metacal"] - # Define explicit colors for each histogram -colors = ['C0', 'C1', 'C2', 'C3'] # Use matplotlib default color cycle +colors = [f'C{i}' for i in range(len(col_names))] # Use matplotlib default color cycle color_map = dict(zip(col_names, colors)) @@ -333,32 +487,38 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): 'label': label, 'n_valid': n_valid, } - +# %% +# Create plots +fig, axes = plt.subplots( + 1, + 2, + figsize=(2 * figsize, figsize) +) # Plot histogram data -plt.figure() -fig, (ax) = plt.subplots(1, 1, figsize=(figsize, figsize)) - -for col_name in col_names: - if col_name in hist_data: - data = hist_data[col_name] - plot_hist( - data['counts'], - data['bins'], - data['label'], - alpha=alpha, - ax=ax, - color=color_map[col_name] - ) - #print(f"{col_name}: n_valid = {data['n_valid']}") - -ax.set_xlabel('$r$') -ax.set_ylabel('Number') -ax.set_xlim(17.5, 26.5) -ax.legend() - +plot_all_hists( + hist_data, + col_names, + alpha=alpha, + color_map=color_map, + ax=axes[0], +) +plot_all_hists( + hist_data, + col_names, + alpha=alpha, + color_map=color_map, + fraction=True, + ax=axes[1], +) plt.tight_layout() -plt.savefig('magnitude_histograms.png', dpi=150, bbox_inches='tight') +out_path = f"magnitude_histograms_scenario-{scenario}.png" +plt.savefig( + out_path, + dpi=150, + bbox_inches='tight' +) +# %% # Save histogram data to file (only if we computed it) if dat is not None: np.savez( @@ -381,4 +541,9 @@ def plot_hist(counts, bins, label, alpha=1, ax=None, color=None): # %% if dat is not None: obj.close_hd5() + +# %% +for mask in masks: + mask.print_condition(sys.stdout, latex=True) + # %% diff --git a/src/sp_validation/run_joint_cat.py b/src/sp_validation/run_joint_cat.py index c9d5965..e30bc43 100644 --- a/src/sp_validation/run_joint_cat.py +++ b/src/sp_validation/run_joint_cat.py @@ -1365,32 +1365,34 @@ def print_stats(self, num_obj, f_out=None): sf = f"{self._num_ok/num_obj:10.2%}" self.print_strings(self._col_name, self._label, si, sf, f_out=f_out) - def get_sign(self): + def get_sign(self, latex=False): sign = None - if self._kind =="equal": - sign = "=" - elif self._kind =="not_equal": - sign = "!=" - elif self._kind =="greater_equal": - sign = ">=" - elif self._kind =="smaller_equal": - sign = "<=" + if self._kind == "equal": + sign = "$=$" if latex else "=" + elif self._kind == "not_equal": + sign = "$\ne$" if latex else "!=" + elif self._kind in ("greater_equal", "range"): + sign = "$\leq$" if latex else ">=" + elif self._kind == "smaller_equal": + sign = "$\geq$" if latex else "<=" return sign - def print_condition(self, f_out): + def print_condition(self, f_out, latex=False): if self._value is None: return "" - sign = self.get_sign() + sign = self.get_sign(latex=latex) + + name = self._label if latex else self._col_name if sign is not None: - print(f"{self._col_name} {sign} {self._value}", file=f_out) - + print(f"{name} {sign} {self._value}", file=f_out) + if self._kind == "range": - print(f"{self._value[0]} <= {self._col_name} <= {self._value[1]}", file=f_out) - + print(f"{self._value[0]} {sign} {name} {sign} {self._value[1]}", file=f_out) + def print_summary(self, f_out): print(f"[{self._label}]\t\t\t", file=f_out, end="") self.print_condition(f_out) From f4f0bb99df63569837bfa19cd3bdad4a76791791 Mon Sep 17 00:00:00 2001 From: martinkilbinger Date: Mon, 17 Nov 2025 13:13:17 +0100 Subject: [PATCH 3/3] hist plot, basic: mask size and SNR separately --- .../cosmo_val/catalog_paper_plot/hist_mag.py | 23 ++++++++++++++++ src/sp_validation/basic.py | 26 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py index 4ced303..7d202d7 100644 --- a/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py +++ b/notebooks/cosmo_val/catalog_paper_plot/hist_mag.py @@ -417,6 +417,25 @@ def plot_all_hists( kind="none", ) +# %% +def get_info_for_metacal_masking(dat, mask, prefix = "NGMIX", name_shear = "NOSHEAR"): + + res = {} + + res["flag"] = dat[mask][f"{prefix}_FLAGS_{name_shear}"] + + for key in ("flux", "flux_err", "T"): + res[key] = dat[mask][f"{prefix}_{key.upper()}_{name_shear}"] + res["Tpsf"] = dat[mask][f"{prefix}_Tpsf_{name_shear}"] + + return res + +# %% +if dat is not None: + cm = config["metacal"] + + + # %% # Call metacal if data is available if dat is not None: @@ -547,3 +566,7 @@ def plot_all_hists( mask.print_condition(sys.stdout, latex=True) # %% +# print number of valid objects and name +for data in hist_data + +# %% diff --git a/src/sp_validation/basic.py b/src/sp_validation/basic.py index 3720004..fc96230 100644 --- a/src/sp_validation/basic.py +++ b/src/sp_validation/basic.py @@ -597,3 +597,29 @@ def jackknif_weighted_average2( all_est = np.array(all_est) return np.mean(all_est), np.std(all_est) + + +def mask_gal_size(T, Tpsf, rel_size_min, rel_size_max, size_corr_ell=False, g1=None, g2=None): + + Tr_tmp = T + if size_corr_ell: + Tr_tmp *= ( + (1 - g1 **2 + g2 ** 2) / (1 + g1 ** 2 + g2 **2) + ) + + mask = ( + (Tr_tmp / Tpsf > rel_size_min) + & (Tr_tmp / Tpsf < rel_size_max) + ) + + return mask + + +def mask_gal_SNR(SNR, snr_min, snr_max): + + mask = ( + (SNR > snr_min) + & (SNR < snr_max) + ) + + return mask