Skip to content

Commit ab66321

Browse files
author
Corentin
committed
adding erosion setting
1 parent 6df17c4 commit ab66321

File tree

6 files changed

+435
-341
lines changed

6 files changed

+435
-341
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,5 @@ debug_data/
169169
!nuclei.tif
170170
!cytoplasm.tif
171171
!binary_mask_sdh.tif
172-
data/*
172+
data/*
173+
sample_atp_intensity_plot.png

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"request": "launch",
2727
"module": "myoquant",
2828
"justMyCode": true,
29-
"args": ["atp-analysis", "sample_img/sample_atp.jpg", "--cellpose-path", "sample_img/sample_atp_cellpose_mask.tiff"],
29+
"args": ["atp-analysis", "sample_img/sample_atp.jpg", "--cellpose-path", "sample_img/sample_atp_cellpose_mask.tiff", "--intensity-method", "mean", "--n-classes", "2", "--erosion"],
3030
}
3131
]
3232
}

myoquant/commands/run_atp.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ def atp_analysis(
5959
None,
6060
help="Approximative single cell diameter in pixel for CellPose detection. If not specified, Cellpose will try to deduce it.",
6161
),
62+
channel: int = typer.Option(
63+
None,
64+
help="Image channel to use for the analysis. If not specified, the analysis will be performed on all three channels.",
65+
),
66+
n_classes: int = typer.Option(
67+
2,
68+
max=10,
69+
help="The number of classes of cell to detect. If not specified this is defaulted to two classes.",
70+
),
71+
intensity_method: str = typer.Option(
72+
"median",
73+
help="The method to use to compute the intensity of the cell. Can be either 'median' or 'mean'.",
74+
),
75+
erosion: bool = typer.Option(
76+
False,
77+
help="Perform an erosion on the cells images to remove signal in the cell membrane (usefull for fluo)",
78+
),
6279
export_map: bool = typer.Option(
6380
True,
6481
help="Export the original image with cells painted by classification label.",
@@ -135,6 +152,8 @@ def atp_analysis(
135152
) as progress:
136153
progress.add_task(description="Reading all inputs...", total=None)
137154
image_ndarray = imread(image_path)
155+
if channel is not None:
156+
image_ndarray = image_ndarray[:, :, channel]
138157

139158
if mask_path is not None:
140159
mask_ndarray = imread(mask_path)
@@ -200,8 +219,13 @@ def atp_analysis(
200219
transient=False,
201220
) as progress:
202221
progress.add_task(description="Detecting fiber types...", total=None)
203-
result_df, full_label_map, df_cellpose_details = run_atp_analysis(
204-
image_ndarray, mask_cellpose, intensity_threshold
222+
result_df, full_label_map, df_cellpose_details, fig = run_atp_analysis(
223+
image_ndarray,
224+
mask_cellpose,
225+
intensity_threshold,
226+
n_classes,
227+
intensity_method,
228+
erosion,
205229
)
206230
if export_map:
207231
with Progress(
@@ -214,7 +238,12 @@ def atp_analysis(
214238
description="Blending label and original image together...", total=None
215239
)
216240
labelRGB_map = label2rgb(image_ndarray, full_label_map)
217-
overlay_img = blend_image_with_label(image_ndarray, labelRGB_map)
241+
if channel is not None:
242+
overlay_img = blend_image_with_label(
243+
image_ndarray, labelRGB_map, fluo=True
244+
)
245+
else:
246+
overlay_img = blend_image_with_label(image_ndarray, labelRGB_map)
218247
overlay_filename = image_path.stem + "_label_blend.tiff"
219248
overlay_img.save(output_path / overlay_filename)
220249

@@ -239,6 +268,11 @@ def atp_analysis(
239268
f"💾 OUTPUT: Summary Table saved as {output_path/csv_name}",
240269
style="green",
241270
)
271+
plot_name = image_path.stem + "_intensity_plot.png"
272+
fig.savefig(output_path / plot_name)
273+
console.print(
274+
f"💾 OUTPUT: Intensity Plot saved as {output_path/plot_name}", style="green"
275+
)
242276
if export_map:
243277
console.print(
244278
f"💾 OUTPUT: Overlay image saved as {output_path/overlay_filename}",

myoquant/src/ATP_analysis.py

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,65 @@
33
from sklearn.mixture import GaussianMixture
44
from .common_func import extract_single_image, df_from_cellpose_mask
55
import numpy as np
6+
import matplotlib
7+
8+
matplotlib.use("agg")
9+
610
import matplotlib.pyplot as plt
711

812
labels_predict = {1: "fiber type 1", 2: "fiber type 2"}
913
np.random.seed(42)
1014

1115

12-
def get_all_intensity(image_array, df_cellpose):
16+
def get_all_intensity(
17+
image_array, df_cellpose, intensity_method="median", erosion=False
18+
):
1319
all_cell_median_intensity = []
1420
for index in range(len(df_cellpose)):
15-
single_cell_img = extract_single_image(image_array, df_cellpose, index)
21+
single_cell_img = extract_single_image(image_array, df_cellpose, index, erosion)
1622

1723
# Calculate median pixel intensity of the cell but ignore 0 values
18-
single_cell_median_intensity = np.median(single_cell_img[single_cell_img > 0])
24+
if intensity_method == "median":
25+
single_cell_median_intensity = np.median(
26+
single_cell_img[single_cell_img > 0]
27+
)
28+
elif intensity_method == "mean":
29+
single_cell_median_intensity = np.mean(single_cell_img[single_cell_img > 0])
1930
all_cell_median_intensity.append(single_cell_median_intensity)
2031
return all_cell_median_intensity
2132

2233

23-
def estimate_threshold(intensity_list):
34+
def estimate_threshold(intensity_list, n_classes=2):
2435
density = gaussian_kde(intensity_list)
2536
density.covariance_factor = lambda: 0.25
2637
density._compute_covariance()
2738

2839
# Create a vector of 256 values going from 0 to 256:
2940
xs = np.linspace(0, 255, 256)
3041
density_xs_values = density(xs)
31-
gmm = GaussianMixture(n_components=2).fit(np.array(intensity_list).reshape(-1, 1))
42+
gmm = GaussianMixture(n_components=n_classes).fit(
43+
np.array(intensity_list).reshape(-1, 1)
44+
)
3245

3346
# Find the x values of the two peaks
3447
peaks_x = np.sort(gmm.means_.flatten())
3548
# Find the minimum point between the two peaks
36-
min_index = np.argmin(density_xs_values[(xs > peaks_x[0]) & (xs < peaks_x[1])])
37-
threshold = peaks_x[0] + xs[min_index]
3849

39-
return threshold
50+
threshold_list = []
51+
length = len(peaks_x)
52+
for index, peaks in enumerate(peaks_x):
53+
if index == length - 1:
54+
break
55+
min_index = np.argmin(
56+
density_xs_values[(xs > peaks) & (xs < peaks_x[index + 1])]
57+
)
58+
threshold_list.append(peaks + xs[min_index])
59+
return threshold_list
4060

4161

42-
def plot_density(all_cell_median_intensity, intensity_threshold):
62+
def plot_density(all_cell_median_intensity, intensity_threshold, n_classes=2):
4363
if intensity_threshold == 0:
44-
intensity_threshold = estimate_threshold(all_cell_median_intensity)
64+
intensity_threshold = estimate_threshold(all_cell_median_intensity, n_classes)
4565
fig, ax = plt.subplots(figsize=(10, 5))
4666
density = gaussian_kde(all_cell_median_intensity)
4767
density.covariance_factor = lambda: 0.25
@@ -51,22 +71,44 @@ def plot_density(all_cell_median_intensity, intensity_threshold):
5171
xs = np.linspace(0, 255, 256)
5272
density_xs_values = density(xs)
5373
ax.plot(xs, density_xs_values, label="Estimated Density")
54-
ax.axvline(x=intensity_threshold, color="red", label="Threshold")
74+
for values in intensity_threshold:
75+
ax.axvline(x=values, color="red", label="Threshold")
5576
ax.set_xlabel("Pixel Intensity")
5677
ax.set_ylabel("Density")
5778
ax.legend()
5879
return fig
5980

6081

61-
def predict_all_cells(histo_img, cellpose_df, intensity_threshold):
62-
all_cell_median_intensity = get_all_intensity(histo_img, cellpose_df)
82+
def merge_peaks_too_close(peak_list):
83+
pass
84+
85+
86+
def classify_cells_intensity(all_cell_median_intensity, intensity_threshold):
87+
muscle_fiber_type_all = []
88+
for intensity in all_cell_median_intensity:
89+
class_cell = np.searchsorted(intensity_threshold, intensity, side="right")
90+
muscle_fiber_type_all.append(class_cell)
91+
return muscle_fiber_type_all
92+
93+
94+
def predict_all_cells(
95+
histo_img,
96+
cellpose_df,
97+
intensity_threshold,
98+
n_classes=2,
99+
intensity_method="median",
100+
erosion=False,
101+
):
102+
all_cell_median_intensity = get_all_intensity(
103+
histo_img, cellpose_df, intensity_method, erosion
104+
)
63105
if intensity_threshold is None:
64-
intensity_threshold = estimate_threshold(all_cell_median_intensity)
106+
intensity_threshold = estimate_threshold(all_cell_median_intensity, n_classes)
65107

66-
muscle_fiber_type_all = [
67-
1 if x > intensity_threshold else 2 for x in all_cell_median_intensity
68-
]
69-
return muscle_fiber_type_all, all_cell_median_intensity
108+
muscle_fiber_type_all = classify_cells_intensity(
109+
all_cell_median_intensity, intensity_threshold
110+
)
111+
return muscle_fiber_type_all, all_cell_median_intensity, intensity_threshold
70112

71113

72114
def paint_full_image(image_atp, df_cellpose, class_predicted_all):
@@ -76,24 +118,46 @@ def paint_full_image(image_atp, df_cellpose, class_predicted_all):
76118
# for index in track(range(len(df_cellpose)), description="Painting cells"):
77119
for index in range(len(df_cellpose)):
78120
single_cell_mask = df_cellpose.iloc[index, 9].copy()
79-
if class_predicted_all[index] == 1:
80-
image_atp_paint[
81-
df_cellpose.iloc[index, 5] : df_cellpose.iloc[index, 7],
82-
df_cellpose.iloc[index, 6] : df_cellpose.iloc[index, 8],
83-
][single_cell_mask] = 1
84-
elif class_predicted_all[index] == 2:
85-
image_atp_paint[
86-
df_cellpose.iloc[index, 5] : df_cellpose.iloc[index, 7],
87-
df_cellpose.iloc[index, 6] : df_cellpose.iloc[index, 8],
88-
][single_cell_mask] = 2
121+
image_atp_paint[
122+
df_cellpose.iloc[index, 5] : df_cellpose.iloc[index, 7],
123+
df_cellpose.iloc[index, 6] : df_cellpose.iloc[index, 8],
124+
][single_cell_mask] = (
125+
class_predicted_all[index] + 1
126+
)
89127
return image_atp_paint
90128

91129

92-
def run_atp_analysis(image_array, mask_cellpose, intensity_threshold=None):
130+
def label_list_from_threhsold(threshold_list):
131+
label_list = []
132+
length = len(threshold_list)
133+
for index, threshold in enumerate(threshold_list):
134+
if index == 0:
135+
label_list.append(f"<{threshold}")
136+
if index == length - 1:
137+
label_list.append(f">{threshold}")
138+
else:
139+
label_list.append(f">{threshold} & <{threshold_list[index+1]}")
140+
return label_list
141+
142+
143+
def run_atp_analysis(
144+
image_array,
145+
mask_cellpose,
146+
intensity_threshold=None,
147+
n_classes=2,
148+
intensity_method="median",
149+
erosion=False,
150+
):
93151
df_cellpose = df_from_cellpose_mask(mask_cellpose)
94-
class_predicted_all, intensity_all = predict_all_cells(
95-
image_array, df_cellpose, intensity_threshold
152+
class_predicted_all, intensity_all, intensity_threshold = predict_all_cells(
153+
image_array,
154+
df_cellpose,
155+
intensity_threshold,
156+
n_classes,
157+
intensity_method,
158+
erosion,
96159
)
160+
fig = plot_density(intensity_all, intensity_threshold, n_classes)
97161
df_cellpose["muscle_cell_type"] = class_predicted_all
98162
df_cellpose["cell_intensity"] = intensity_all
99163
count_per_label = np.unique(class_predicted_all, return_counts=True)
@@ -102,15 +166,16 @@ def run_atp_analysis(image_array, mask_cellpose, intensity_threshold=None):
102166
headers = ["Feature", "Raw Count", "Proportion (%)"]
103167
data = []
104168
data.append(["Muscle Fibers", len(class_predicted_all), 100])
169+
label_list = label_list_from_threhsold(intensity_threshold)
105170
for index, elem in enumerate(count_per_label[0]):
106171
data.append(
107172
[
108-
labels_predict[int(elem)],
173+
label_list[int(elem)],
109174
count_per_label[1][int(index)],
110175
100 * count_per_label[1][int(index)] / len(class_predicted_all),
111176
]
112177
)
113178
result_df = pd.DataFrame(columns=headers, data=data)
114179
# Paint The Full Image
115180
full_label_map = paint_full_image(image_array, df_cellpose, class_predicted_all)
116-
return result_df, full_label_map, df_cellpose
181+
return result_df, full_label_map, df_cellpose, fig

myoquant/src/common_func.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
77
import sys
8+
import math
9+
810
import tensorflow as tf
911
import torch
1012
from cellpose.models import Cellpose
@@ -14,11 +16,14 @@
1416
import numpy as np
1517
from PIL import Image
1618
from skimage.measure import regionprops_table
19+
from skimage.morphology import binary_erosion
1720
import pandas as pd
1821

1922
# from .gradcam import make_gradcam_heatmap, save_and_display_gradcam
2023
from .random_brightness import RandomBrightness
2124

25+
import imageio
26+
2227
tf.random.set_seed(42)
2328
np.random.seed(42)
2429

@@ -168,12 +173,21 @@ def df_from_stardist_mask(mask, intensity_image=None):
168173
return df_stardist
169174

170175

171-
def extract_single_image(raw_image, df_props, index):
176+
def extract_single_image(raw_image, df_props, index, erosion=False):
172177
single_entity_img = raw_image[
173178
df_props.iloc[index, 5] : df_props.iloc[index, 7],
174179
df_props.iloc[index, 6] : df_props.iloc[index, 8],
175180
].copy()
181+
surface_area = df_props.iloc[index, 1]
182+
cell_radius = math.sqrt(surface_area / math.pi)
176183
single_entity_mask = df_props.iloc[index, 9]
184+
erosion_size = int(cell_radius / 5) # 20% of the cell
185+
if erosion:
186+
for i in range(erosion_size):
187+
single_entity_mask = binary_erosion(
188+
single_entity_mask, out=single_entity_mask
189+
)
190+
177191
single_entity_img[~single_entity_mask] = 0
178192
return single_entity_img
179193

@@ -192,6 +206,14 @@ def label2rgb(img_ndarray, label_map):
192206
0: [255, 255, 255],
193207
1: [15, 157, 88],
194208
2: [219, 68, 55],
209+
3: [100, 128, 170],
210+
4: [231, 204, 143],
211+
5: [202, 137, 115],
212+
6: [178, 143, 172],
213+
7: [144, 191, 207],
214+
8: [148, 187, 187],
215+
9: [78, 86, 105],
216+
10: [245, 90, 87],
195217
}
196218
img_rgb = np.zeros((img_ndarray.shape[0], img_ndarray.shape[1], 3), dtype=np.uint8)
197219

0 commit comments

Comments
 (0)