Skip to content

Commit 8a76ae6

Browse files
authored
Fix #810
- Update the docstring of `predict_torch` Based on the conversation of #810 the docstring was updated to provide the user a hint how to interpret the dimensions of the outputs.
1 parent dca509f commit 8a76ae6

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

segment_anything/predictor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,16 @@ def predict_torch(
202202
instead of a binary mask.
203203
204204
Returns:
205-
(torch.Tensor): The output masks in BxCxHxW format, where C is the
206-
number of masks, and (H, W) is the original image size.
205+
(torch.Tensor): The output masks in BxCxHxW format, where B is the
206+
number of batches, C is the number of masks per batch, and (H, W) is
207+
the original image size.
208+
The meaning of B depends on the prompt input.
207209
(torch.Tensor): An array of shape BxC containing the model's
208-
predictions for the quality of each mask.
209-
(torch.Tensor): An array of shape BxCxHxW, where C is the number
210-
of masks and H=W=256. These low res logits can be passed to
211-
a subsequent iteration as mask input.
210+
predictions for the quality of each mask per batch.
211+
(torch.Tensor): An array of shape BxCxHxW, where B is the
212+
number of batches, C is the number of masks per batch and H=W=256.
213+
These low res logits can be passed to a subsequent iteration as mask input.
214+
The meaning of B depends on the prompt input.
212215
"""
213216
if not self.is_image_set:
214217
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

0 commit comments

Comments
 (0)