diff --git a/Functional_Fusion/atlas_map.py b/Functional_Fusion/atlas_map.py
index d246e87..6435c2a 100644
--- a/Functional_Fusion/atlas_map.py
+++ b/Functional_Fusion/atlas_map.py
@@ -132,6 +132,46 @@ def parcel_recombine(label_vector,parcels_selected,label_id=None,label_name=None
         raise ValueError('parcels_selected must be a list')
     return label_vector_new, label_id_new, label_name_new
 
+
+def parcel_combine(img, output_filename=None):
+    """
+    Combines multiple ROI masks (NIfTI files or images) into a single NIfTI image in which each ROI carries a unique integer label.
+
+    Parameters:
+    - img (list of str or Nifti1Image): Paths to NIfTI mask files or already-loaded NIfTI mask images.
+    - output_filename (str, optional): Path to save the combined NIfTI file.
+
+    Returns:
+    - combined_img (Nifti1Image): Combined image in which the i-th ROI carries label i (also saved to output_filename if given).
+    """
+    # Load the first image to get shape and affine transformation
+    if isinstance(img[0], str):
+        reference_img = nb.load(img[0])
+    elif isinstance(img[0], nb.Nifti1Image):
+        reference_img = img[0]
+    combined_data = np.zeros(reference_img.shape, dtype=np.int16)
+
+    # Assign unique labels to each ROI
+    for i, mask in enumerate(img, start=1):
+        if isinstance(mask, str):
+            roi_img = nb.load(mask)
+        elif isinstance(mask, nb.Nifti1Image):
+            roi_img = mask
+        roi_data = roi_img.get_fdata()
+
+        # Ensure binary mask (in case input masks have non-binary values)
+        roi_mask = roi_data > 0
+
+        # Assign a unique label to this ROI
+        combined_data[roi_mask] = i
+
+    # Save the combined ROI mask as a new NIfTI file (only if a filename was given)
+    combined_img = nb.Nifti1Image(combined_data, reference_img.affine, reference_img.header)
+    if output_filename is not None: nb.save(combined_img, output_filename)
+
+    return combined_img
+
+
 class Atlas:
     def __init__(self, name, structure='unknown', space='unknown'):
         """ The Atlas class implements the mapping from the P brain locations back to the defining
@@ -1090,12 +1130,15 @@ def build(self, depths=[0, 0.2, 0.4, 0.6, 0.8, 1.0]):
             indices[i, :, :] = (1 - depths[i]) * c1 + depths[i] * c2
 
         self.vox_list, good = nt.coords_to_linvidxs(indices, self.mask_img, mask=True)
-        all = good.sum(axis=0)
+        # all = good.sum(axis=0)
+        _, invx, count = np.unique(self.vox_list, return_inverse=True, return_counts=True)
         # print(f'{self.name} has {np.sum(all==0)} vertices without data')
-        all[all == 0] = 1
-        self.vox_weight = good / all
+        # all[all == 0] = 1
+        self.vox_weight = count[invx]  # good / all
         self.vox_list = self.vox_list.T
-        self.vox_weight = self.vox_weight.T
+        self.vox_weight = self.vox_weight.T
+
+        pass
 
 def get_data_nifti(fnames, atlas_maps):
     """Extracts the data for a list of fnames
@@ -1201,9 +1244,27 @@ def exclude_overlapping_voxels(amap, exclude='all', exclude_thres=0.9):
             vox_j, weight_j = amap[j].vox_list, amap[j].vox_weight
             vox_k, weight_k = amap[k].vox_list, amap[k].vox_weight
 
-            EQ = vox_j.flatten()[:, np.newaxis] == vox_k.flatten()[np.newaxis, :]
-
-            idx_j, idx_k = np.where(EQ)
+            # EQ = vox_j.flatten()[:, np.newaxis] == vox_k.flatten()[np.newaxis, :]
+            #
+            # idx_j, idx_k = np.where(EQ)
+            vox_j = vox_j.flatten()
+            vox_k = vox_k.flatten()
+
+            # Sort vox_k and keep track of the original indices
+            sort_idx = np.argsort(vox_k)
+            vox_k_sorted = vox_k[sort_idx]
+
+            # Check which elements in vox_j exist in vox_k
+            mask = np.isin(vox_j, vox_k_sorted)
+
+            # Find the corresponding indices in vox_k
+            idx_j = np.where(mask)[0]  # Indices in vox_j
+            idx_k = np.searchsorted(vox_k_sorted, vox_j[mask])  # Indices in sorted vox_k
+
+            # Convert back to original vox_k indices
+            idx_k = sort_idx[idx_k]
+
+            print(f'found {len(idx_j)} overlapping voxels')
 
             for idx_j_v, idx_k_v in zip(idx_j, idx_k):
                 wj, wk = weight_j.flatten()[idx_j_v], weight_k.flatten()[idx_k_v]
diff --git a/Functional_Fusion/dataset.py b/Functional_Fusion/dataset.py
index 5b2fd1f..3b2f8b4 100644
--- a/Functional_Fusion/dataset.py
+++ b/Functional_Fusion/dataset.py
@@ -229,7 +229,7 @@ def agg_data(info, by, over, subset=None):
     return data_info, C
 
 
-def agg_parcels(data, label_vec, fcn=np.nanmean):
+def agg_parcels(data, label_vec, fcn=np.nanmean, **kwargs):
     """ Aggregates data over colums to condense to parcels
     Args:
@@ -248,7 +248,7 @@ def agg_parcels(data, label_vec, fcn=np.nanmean):
     parcel_data = np.zeros(psize)
     for i, l in enumerate(labels):
         parcel_data[..., i] = fcn(
-            data[..., label_vec == l], axis=len(psize) - 1)
+            data[..., label_vec == l], axis=len(psize) - 1, **kwargs)
     return parcel_data, labels
 
 def combine_parcel_labels(labels_org,labels_new, labelvec_org=None):
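Note on the new overlap search in `exclude_overlapping_voxels`: the dense equality matrix `vox_j[:, None] == vox_k[None, :]` is replaced by a sort plus `np.searchsorted` lookup, which avoids allocating an N x M boolean array. The snippet below is a minimal, self-contained sketch of that index-matching scheme using made-up arrays (it is not a test from the package). One behavioural difference worth keeping in mind: `searchsorted` yields a single matching position per element, so repeated entries in `vox_k` no longer each produce their own match.

```python
import numpy as np

# Made-up linear voxel indices for two atlas maps (illustration only)
vox_j = np.array([3, 7, 7, 12, 40, 55])
vox_k = np.array([12, 3, 99, 55, 40])

# Sort vox_k once and remember where each sorted element came from
sort_idx = np.argsort(vox_k)
vox_k_sorted = vox_k[sort_idx]

# Which entries of vox_j also occur in vox_k?
mask = np.isin(vox_j, vox_k_sorted)
idx_j = np.where(mask)[0]                                      # positions in vox_j
idx_k = sort_idx[np.searchsorted(vox_k_sorted, vox_j[mask])]   # positions in the original vox_k

# Every matched pair refers to the same voxel index
for j, k in zip(idx_j, idx_k):
    assert vox_j[j] == vox_k[k]
print(f'found {len(idx_j)} overlapping voxels')
```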
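For reference, a hedged usage sketch of the two entry points touched by this diff: `parcel_combine` and the `**kwargs` pass-through in `agg_parcels`. The file names, array shapes, and labels are hypothetical; the import aliases simply mirror the module paths in the diff, and the `dtype` keyword is only meant to show that extra arguments now reach the aggregation function (`np.nanmean` here).

```python
import numpy as np
import Functional_Fusion.atlas_map as am
import Functional_Fusion.dataset as ds

# Hypothetical ROI mask files; any list of binary NIfTI masks
# (paths or already-loaded Nifti1Image objects) should work.
roi_files = ['roi_left.nii', 'roi_right.nii']

# Combine the masks into one labelled image: voxels of the first mask
# get label 1, the second label 2, and so on.
label_img = am.parcel_combine(roi_files, output_filename='combined_rois.nii')

# Aggregate a (conditions x voxels) data matrix into per-parcel values;
# extra keyword arguments are forwarded to the aggregation function.
data = np.random.randn(10, 5000)            # made-up activity data
label_vec = np.random.randint(1, 4, 5000)   # made-up parcel assignment (labels 1-3)
parcel_data, labels = ds.agg_parcels(data, label_vec, fcn=np.nanmean, dtype=np.float32)
print(parcel_data.shape, labels)
```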