108 changes: 106 additions & 2 deletions flamingo_tools/measurements.py
@@ -6,6 +6,7 @@
import warnings
from concurrent import futures
from functools import partial
from multiprocessing import cpu_count
from typing import List, Optional, Tuple, Union

import numpy as np
@@ -22,13 +23,13 @@
from tqdm import tqdm

from .file_utils import read_image_data
from .segmentation.postprocessing import compute_table_on_the_fly
from .postprocessing.label_components import compute_table_on_the_fly
import flamingo_tools.s3_utils as s3_utils


def _measure_volume_and_surface(mask, resolution):
# Use marching_cubes for 3D data
verts, faces, normals, _ = marching_cubes(mask, spacing=(resolution,) * 3)
verts, faces, normals, _ = marching_cubes(mask, spacing=resolution)

mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
surface = mesh.area
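
Side note on the spacing change above: skimage's marching_cubes accepts a per-axis spacing tuple, so callers now normalize a scalar resolution to a 3-tuple before measuring (see the hunks below). A minimal sketch of that contract on a toy volume, assuming the 0.38 µm default voxel size; the anisotropic LaVision value (3.0, 0.76, 0.76) from the comment further down would be passed directly:

import numpy as np
from skimage.measure import marching_cubes

# Toy foreground volume standing in for a single object crop.
volume = np.zeros((16, 16, 16), dtype=np.float32)
volume[4:12, 4:12, 4:12] = 1.0

resolution = 0.38  # isotropic voxel size in micrometers
if isinstance(resolution, float):
    resolution = (resolution,) * 3  # expand a scalar to per-axis spacing

verts, faces, normals, _ = marching_cubes(volume, spacing=resolution)
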
@@ -166,6 +167,8 @@ def _default_object_features(

# Do the volume and surface measurement.
if not median_only:
if isinstance(resolution, float):
resolution = (resolution,) * 3
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
@@ -181,6 +184,8 @@ def _morphology_features(seg_id, table, image, segmentation, resolution, **kwarg
# Hard-coded value for LaVision cochleae. This is a hack for the wrong voxel size in MoBIE.
# resolution = (3.0, 0.76, 0.76)

if isinstance(resolution, float):
resolution = (resolution,) * 3
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
@@ -498,3 +503,102 @@ def _compute_block(block_id):

mask = ResizedVolume(low_res_mask, shape=original_shape, order=0)
return mask


def object_measures_single(
table_path: str,
seg_path: str,
image_paths: List[str],
out_paths: List[str],
force_overwrite: bool = False,
component_list: List[int] = [1],
background_mask: Optional[np.typing.ArrayLike] = None,
resolution: List[float] = [0.38, 0.38, 0.38],
s3: bool = False,
s3_credentials: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
s3_service_endpoint: Optional[str] = None,
**_
):
"""Compute object measures for a single or multiple image channels in respect to a single segmentation channel.

Args:
table_path: File path to segmentation table.
seg_path: Path to segmentation channel in ome.zarr format.
image_paths: Path(s) to image channel(s) in ome.zarr format.
out_paths: Path(s) for the calculated object measures.
force_overwrite: Forcefully overwrite existing files.
component_list: Only calculate object measures for specific components.
background_mask: Use background mask for calculating object measures.
resolution: Resolution of the input in micrometers.
s3: Use S3 file paths.
s3_credentials: Input file containing S3 credentials. Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.
s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME was exported.
s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.
"""
input_key = "s0"
out_paths = [os.path.realpath(o) for o in out_paths]

if not isinstance(resolution, float):
if len(resolution) == 1:
resolution = resolution * 3
assert len(resolution) == 3
resolution = np.array(resolution)[::-1]
else:
resolution = (resolution,) * 3

for (img_path, out_path) in zip(image_paths, out_paths):
n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count()))

# The output table is the input table, so it has to be overwritten in place.
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
force_overwrite = True

if os.path.isfile(out_path) and not force_overwrite:
print(f"Skipping {out_path}. Table already exists.")

else:
if background_mask is None:
feature_set = "default"
dilation = None
median_only = False
else:
print("Using background mask for calculating object measures.")
feature_set = "default_background_subtract"
dilation = 4
median_only = True

if s3:
img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name,
service_endpoint=s3_service_endpoint,
credential_file=s3_credentials)
seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name,
service_endpoint=s3_service_endpoint,
credential_file=s3_credentials)

mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr")
background_mask = compute_sgn_background_mask(
image_path=img_path,
segmentation_path=seg_path,
image_key=input_key,
segmentation_key=input_key,
n_threads=n_threads,
cache_path=mask_cache_path,
)

compute_object_measures(
image_path=img_path,
segmentation_path=seg_path,
segmentation_table_path=table_path,
output_table_path=out_path,
image_key=input_key,
segmentation_key=input_key,
feature_set=feature_set,
s3_flag=s3,
component_list=component_list,
dilation=dilation,
median_only=median_only,
background_mask=background_mask,
n_threads=n_threads,
resolution=resolution,
)
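
For orientation, a minimal sketch of how the new helper could be called on local data; every path below is a hypothetical placeholder, not part of this PR:

from flamingo_tools.measurements import object_measures_single

object_measures_single(
    table_path="tables/SGN_v2/default.tsv",              # hypothetical segmentation table
    seg_path="images/ome-zarr/SGN_v2.ome.zarr",           # hypothetical segmentation channel
    image_paths=["images/ome-zarr/PV.ome.zarr"],          # one or more intensity channels
    out_paths=["tables/SGN_v2/PV_object_measures.tsv"],   # one output table per image channel
    component_list=[1],                                   # restrict to the first connected component
    resolution=[0.38, 0.38, 0.38],                        # voxel size in micrometers
    s3=False,                                             # read from the local file system
)
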
4 changes: 4 additions & 0 deletions flamingo_tools/postprocessing/__init__.py
@@ -0,0 +1,4 @@
"""This module implements the functionality to filter isolated objects from a segmentation.
"""

from .label_components import filter_segmentation
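
A one-line import sketch for the new subpackage; filter_segmentation is re-exported from label_components, whose signature is not shown in this diff, so the call itself is omitted:

from flamingo_tools.postprocessing import filter_segmentation
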
149 changes: 149 additions & 0 deletions flamingo_tools/postprocessing/cli.py
@@ -0,0 +1,149 @@
"""private
"""
import argparse

from .label_components import label_components_single
from .cochlea_mapping import tonotopic_mapping_single
from flamingo_tools.measurements import object_measures_single


def label_components():
parser = argparse.ArgumentParser(
description="Script to label segmentation using a segmentation table and graph connected components.")

parser.add_argument("-i", "--input", type=str, required=True, help="Input path to segmentation table.")
parser.add_argument("-o", "--output", type=str, required=True,
help="Output path. Either directory (for --json) or specific file otherwise.")
parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")

# options for post-processing
parser.add_argument("--cell_type", type=str, default="sgn",
help="Cell type of segmentation. Either 'sgn' or 'ihc'.")
parser.add_argument("--min_size", type=int, default=1000,
help="Minimal number of pixels for filtering small instances.")
parser.add_argument("--min_component_length", type=int, default=50,
help="Minimal length for filtering out connected components.")
parser.add_argument("--max_edge_distance", type=float, default=30,
help="Maximal distance in micrometer between points to create edges for connected components.")
parser.add_argument("-c", "--components", type=int, nargs="+", default=[1], help="List of connected components.")

# options for S3 bucket
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
parser.add_argument("--s3_credentials", type=str, default=None,
help="Input file containing S3 credentials. "
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
parser.add_argument("--s3_bucket_name", type=str, default=None,
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
parser.add_argument("--s3_service_endpoint", type=str, default=None,
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

args = parser.parse_args()

label_components_single(
table_path=args.input,
out_path=args.output,
cell_type=args.cell_type,
component_list=args.components,
max_edge_distance=args.max_edge_distance,
min_component_length=args.min_component_length,
min_size=args.min_size,
force_overwrite=args.force,
s3=args.s3,
s3_credentials=args.s3_credentials,
s3_bucket_name=args.s3_bucket_name,
s3_service_endpoint=args.s3_service_endpoint,
)


def tonotopic_mapping():
parser = argparse.ArgumentParser(
description="Script to extract region of interest (ROI) block around center coordinate.")

parser.add_argument("-i", "--input", type=str, required=True, help="Input path to segmentation table.")
parser.add_argument("-o", "--output", type=str, required=True,
help="Output path. Either directory or specific file.")
parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")

# options for tonotopic mapping
parser.add_argument("--animal", type=str, default="mouse",
help="Animal type to be used for frequency mapping. Either 'mouse' or 'gerbil'.")
parser.add_argument("--otof", action="store_true", help="Use frequency mapping for OTOF cochleae.")
parser.add_argument("--apex_position", type=str, default="apex_higher",
help="Use frequency mapping for OTOF cochleae.")

# options for post-processing
parser.add_argument("--cell_type", type=str, default="sgn",
help="Cell type of segmentation. Either 'sgn' or 'ihc'.")
parser.add_argument("--max_edge_distance", type=float, default=30,
help="Maximal distance in micrometer between points to create edges for connected components.")
parser.add_argument("-c", "--components", type=int, nargs="+", default=[1], help="List of connected components.")

# options for S3 bucket
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
parser.add_argument("--s3_credentials", type=str, default=None,
help="Input file containing S3 credentials. "
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
parser.add_argument("--s3_bucket_name", type=str, default=None,
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
parser.add_argument("--s3_service_endpoint", type=str, default=None,
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

args = parser.parse_args()

tonotopic_mapping_single(
table_path=args.input,
out_path=args.output,
force_overwrite=args.force,
animal=args.animal,
otof=args.otof,
apex_position=args.apex_position,
cell_type=args.cell_type,
max_edge_distance=args.max_edge_distance,
component_list=args.components,
s3=args.s3,
s3_credentials=args.s3_credentials,
s3_bucket_name=args.s3_bucket_name,
s3_service_endpoint=args.s3_service_endpoint,
)


def object_measures():
parser = argparse.ArgumentParser(
description="Script to compute object measures for different stainings.")

parser.add_argument("-o", "--output", type=str, nargs="+", required=True,
help="Output path(s). Either directory or specific file(s).")
parser.add_argument("-i", "--image_paths", type=str, nargs="+", default=None,
help="Input path to one or multiple image channels in ome.zarr format.")
parser.add_argument("-t", "--seg_table", type=str, default=None,
help="Input path to segmentation table.")
parser.add_argument("-s", "--seg_path", type=str, default=None,
help="Input path to segmentation channel in ome.zarr format.")
parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")

# options for object measures
parser.add_argument("-c", "--components", type=int, nargs="+", default=[1], help="List of components.")
parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[0.38, 0.38, 0.38],
help="Resolution of input in micrometer.")
parser.add_argument("--bg_mask", action="store_true", help="Use background mask for calculating object measures.")

# options for S3 bucket
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
parser.add_argument("--s3_credentials", type=str, default=None,
help="Input file containing S3 credentials. "
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
parser.add_argument("--s3_bucket_name", type=str, default=None,
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
parser.add_argument("--s3_service_endpoint", type=str, default=None,
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

args = parser.parse_args()

object_measures_single(
out_paths=args.output,
image_paths=args.image_paths,
table_path=args.seg_table,
seg_path=args.seg_path,
force_overwrite=args.force,
component_list=args.components,
resolution=args.resolution,
# --bg_mask only requests the background-subtracted feature set;
# the actual mask is built inside object_measures_single.
background_mask=args.bg_mask or None,
s3=args.s3,
s3_credentials=args.s3_credentials,
s3_bucket_name=args.s3_bucket_name,
s3_service_endpoint=args.s3_service_endpoint,
)
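
As a usage illustration only (console-script wiring is not part of this diff), the object_measures entry point can be driven by patching sys.argv; the file names are hypothetical placeholders:

import sys
from flamingo_tools.postprocessing.cli import object_measures

sys.argv = [
    "object_measures",
    "-t", "SGN_table.tsv",               # segmentation table
    "-s", "SGN_segmentation.ome.zarr",   # segmentation channel
    "-i", "PV_channel.ome.zarr",         # image channel(s)
    "-o", "PV_object_measures.tsv",      # output table(s)
    "-r", "0.38",                        # isotropic resolution in micrometers
]
object_measures()
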