Commit 770cf37

Jin Xu authored and copybara-github committed
Migrate inputs.py
PiperOrigin-RevId: 646187131
1 parent 371a597 commit 770cf37

1 file changed

ffn/training/inputs.py

Lines changed: 326 additions & 20 deletions
--- a/ffn/training/inputs.py
+++ b/ffn/training/inputs.py
@@ -14,9 +14,15 @@
 # ==============================================================================
 """Tensorflow Python ops and utilities for generating network inputs."""
 
+import random
 import re
+from typing import Any, Callable, Optional, Sequence
 
+from absl import logging
 from connectomics.common import bounding_box
+from connectomics.common import box_generator
+from connectomics.segmentation import labels as label_utils
+from ffn.training import augmentation
 import numpy as np
 import tensorflow.compat.v1 as tf
 from tensorflow.io import gfile
@@ -28,47 +34,203 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
   Args:
     coordinates_file_pattern: File pattern for TFRecords of
       input examples of the form of a glob
-      pattern or path@shards.
+      pattern or path@shards,
+      or comma-separated file patterns.
     shuffle: Whether to shuffle the coordinate file list. Note that the expanded
       coordinates_file_pattern is not guaranteed to be sorted
       alphabetically.
 
   Returns:
     Tensorflow queue with coordinate filenames
   """
-  m = re.search(r'@(\d{1,})', coordinates_file_pattern)
-  if m:
-    num_shards = int(m.group(1))
-    coord_file_list = [
-        re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards),
-               coordinates_file_pattern)
-        for i in range(num_shards)]
-  else:
-    coord_file_list = gfile.glob(coordinates_file_pattern)
+  coord_file_list = []
+  for pattern in coordinates_file_pattern.split(','):
+    m = re.search(r'@(\d{1,})', pattern)
+    if m:
+      num_shards = int(m.group(1))
+      coord_file_list.extend([
+          re.sub(
+              r'@(\d{1,})',
+              '-%.5d-of-%.5d' % (i, num_shards),
+              pattern,
+          )
+          for i in range(num_shards)
+      ])
+    else:
+      coord_file_list.extend(gfile.glob(pattern))
   return tf.train.string_input_producer(coord_file_list, shuffle=shuffle)
 
 
-def load_patch_coordinates_from_filename_queue(filename_queue):
+def load_patch_coordinates_from_filename_queue(filename_queue,
+                                               file_format='tfrecords'):
   """Loads coordinates and volume names from filename queue.
 
   Args:
     filename_queue: Tensorflow queue created from create_filename_queue()
+    file_format: String indicating the format of the files in the queue.
+      Can be 'sstables' or 'tfrecords'. Defaults to 'tfrecords'.
 
   Returns:
     Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.
   """
-  record_options = tf.python_io.TFRecordOptions(
-      tf.python_io.TFRecordCompressionType.GZIP)
-  keys, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
-  examples = tf.parse_single_example(protos, features=dict(
-      center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
-      label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
-  ))
-  coord = examples['center']
-  volname = examples['label_volume_name']
+  if file_format == 'tfrecords':
+    record_options = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.GZIP)
+    _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
+    examples = tf.parse_single_example(protos, features=dict(
+        center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+        label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+    ))
+    coord = examples['center']
+    volname = examples['label_volume_name']
+  else:
+    raise ValueError(f'Unsupported file format: {file_format}.')
+
   return coord, volname
 
 
+def sample_patch_coordinates(
+    bboxes: Sequence[Sequence[bounding_box.BoundingBox]],
+    volinfo_map_string: str,
+    name='sample_patch_coordinates',
+    rng_seed: Optional[int] = None,
+) -> tf.data.Dataset:
+  """Samples a coordinate uniformly at random from specified bboxes.
+
+  Args:
+    bboxes: sequence of sequences of bounding boxes (one seq. per volume)
+    volinfo_map_string: comma-delimited string mapping volname:volinfo_path,
+      where volinfo_path is a gfile with a text_format VolumeInfo proto for
+      the volume from which patches should be extracted.
+    name: passed to `name_scope`
+    rng_seed: random number generator seed, allowing the dataset to be made
+      deterministic
+
+  Returns:
+    tf.data.Dataset of dicts with:
+      'coord': [1, 3] int64 xyz coord tensor
+      'volname': [1] string tensor with the volume label
+
+  Raises:
+    ValueError: if len(bboxes) != len(volinfo_map) or if an invalid bbox is
+      passed
+  """
+  volinfo_pairs = volinfo_map_string.split(',')
+  if len(bboxes) != len(volinfo_pairs):
+    raise ValueError(
+        'Numbers of bounding boxes and volume paths do not match.'
+    )
+
+  volumes, flat_boxes = [], []
+  total_voxels = 0
+  for vol_id, volume_boxes in enumerate(bboxes):
+    for b in volume_boxes:
+      w = np.prod(b.size)
+      if w < 0:
+        raise ValueError('Volume %d, bbox %r is too small.' % (vol_id, b))
+      total_voxels += w
+      flat_boxes.append(b)
+      volumes.append(vol_id)
+
+  calc = box_generator.MultiBoxGenerator(
+      flat_boxes, box_size=(1, 1, 1), box_overlap=(0, 0, 0)
+  )
+  volnames = [v.split(':')[0] for v in volinfo_pairs]
+
+  def _sample_volinfo_and_bbox(idx):
+    idx = idx[0]
+    vol_idx = volumes[calc.index_to_generator_index(idx)[0]]
+    _, coord_bbox = calc.generate(idx)
+    assert coord_bbox is not None
+    logging.info(
+        'Sampled location %r from volume %s',
+        coord_bbox.start,
+        volnames[vol_idx],
+    )
+    coord = np.array([coord_bbox.start]).astype(np.int64)
+    return coord, volnames[vol_idx]
+
+  def _sample(rng_seed):
+    with tf.name_scope(name=name):
+      coord, label = tf.py_func(
+          _sample_volinfo_and_bbox,
+          [
+              tf.random.stateless_uniform(
+                  [1],
+                  rng_seed,
+                  maxval=total_voxels,
+                  dtype=tf.int64,
+                  name='rand',
+              )
+          ],
+          [tf.int64, tf.string],
+          name='sample_volinfo_and_bbox',
+          stateful=False,
+      )
+      label.set_shape([])
+      coord.set_shape([1, 3])
+      return {'coord': coord, 'volname': tf.reshape(label, [1])}
+
+  # This is faster than calling _sample_volinfo_and_bbox via .from_generator.
+  return tf.data.Dataset.random(seed=rng_seed).batch(2).map(_sample)
+
+
+def get_vol_map(volinfo_paths: Sequence[str]):
+  return ','.join(
+      'vol%d:%s' % (i, volinfo) for i, volinfo in enumerate(volinfo_paths)
+  )
+
+
+def parse_tf_coords(x):
+  return tf.io.parse_single_example(
+      x,
+      features=dict(
+          coord=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+          volname=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+          label=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
+          segment_id=tf.FixedLenFeature(
+              shape=[1],
+              dtype=tf.int64,
+              default_value=tf.constant([0], dtype=tf.int64),
+          ),
+          radius=tf.FixedLenFeature(
+              shape=[1],
+              dtype=tf.float32,
+              default_value=tf.constant([0], dtype=tf.float32),
+          ),
+      ),
+  )
+
+
+def load_coordinates_from_tfex(
+    coord_pattern: str,
+    shuffle: bool = True,
+    shuffle_size: Optional[int] = 4096,
+    shuffle_seed: Optional[int] = None,
+    parse_fn: Callable[[Any], dict[str, Any]] = parse_tf_coords,
+    reshuffle_each_iteration: bool = True,
+) -> tf.data.Dataset:
+  """Loads coordinates from a RecordIO of tf.Example protos."""
+  coord_paths = sorted(gfile.Glob(coord_pattern))
+  if shuffle:
+    if shuffle_seed:
+      random.Random(shuffle_seed).shuffle(coord_paths)
+    else:
+      random.shuffle(coord_paths)
+  logging.info('Loading data from: %r', coord_paths)
+  ds = tf.data.RecordIODataset(tf.constant(coord_paths, dtype=tf.string))
+
+  ds = ds.map(parse_fn, deterministic=True)
+  if shuffle:
+    ds = ds.shuffle(
+        shuffle_size,
+        seed=shuffle_seed,
+        reshuffle_each_iteration=reshuffle_each_iteration,
+    )
+
+  return ds.repeat()
+
+
 def load_patch_coordinates(coordinates_file_pattern,
                            shuffle=True,
                            scope='load_patch_coordinates'):
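
Note on the create_filename_queue change above: the pattern argument may now contain several comma-separated patterns, each expanded independently (sharded path@N specs as well as globs). A minimal sketch of the shard expansion, runnable with plain Python; the pattern below is hypothetical:

import re

pattern = 'coords_a@3'  # hypothetical sharded pattern
num_shards = int(re.search(r'@(\d{1,})', pattern).group(1))
expanded = [
    re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards), pattern)
    for i in range(num_shards)
]
print(expanded)
# ['coords_a-00000-of-00003', 'coords_a-00001-of-00003',
#  'coords_a-00002-of-00003']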
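
The new sample_patch_coordinates flattens all per-volume bounding boxes into a MultiBoxGenerator over unit boxes, so a single uniform draw from [0, total_voxels) selects both a volume and a voxel location; the two-element batch from tf.data.Dataset.random supplies the seed pair for tf.random.stateless_uniform. A usage sketch, with hypothetical bounding boxes and volinfo paths, assuming eager execution for the iteration:

from connectomics.common import bounding_box

bboxes = [
    [bounding_box.BoundingBox(start=(0, 0, 0), size=(100, 100, 100))],
    [bounding_box.BoundingBox(start=(10, 10, 10), size=(50, 50, 50))],
]
# get_vol_map builds 'vol0:/path/a.volinfo,vol1:/path/b.volinfo'.
volinfo_map = get_vol_map(['/path/a.volinfo', '/path/b.volinfo'])
ds = sample_patch_coordinates(bboxes, volinfo_map, rng_seed=42)
for item in ds.take(2):
  print(item['coord'], item['volname'])  # [1, 3] int64 xyz, [1] string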
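
load_coordinates_from_tfex is the queue-free counterpart: file paths are shuffled in Python, records are parsed as tf.Example protos via parse_tf_coords, and the result repeats indefinitely. A sketch with a placeholder glob; tf.data.RecordIODataset is assumed to exist in the TensorFlow build in use (RecordIO support is not part of stock open-source TensorFlow):

ds = load_coordinates_from_tfex('/path/to/coords-*-of-00100', shuffle_seed=42)
# Each element is a dict with 'coord' ([1, 3] int64), 'volname' ([1] string),
# 'label' ([1] int64), and defaulted 'segment_id'/'radius' features.
for item in ds.take(1):
  print(item['coord'], item['label'])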
@@ -187,6 +349,7 @@ def get_offset_scale(volname,
   """
 
   def _get_offset_scale(volname):
+    volname = volname.decode('utf-8')
     if volname in offset_scale_map:
       offset, scale = offset_scale_map[volname]
     else:
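
The one-line fix in get_offset_scale addresses a Python 3 pitfall: _get_offset_scale runs under tf.py_func, which delivers tf.string tensors to Python as bytes, so lookups against str keys in offset_scale_map would always miss. A toy illustration:

offset_scale_map = {'vol0': ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))}
volname = b'vol0'  # what tf.py_func actually passes in
print(volname in offset_scale_map)                  # False
print(volname.decode('utf-8') in offset_scale_map)  # True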
@@ -356,3 +519,146 @@ def soften_labels(bool_labels, softness=0.05, scope='soften_labels'):
   return tf.where(bool_labels,
                   tf.fill(label_shape, 1.0 - softness, name='soft_true'),
                   tf.fill(label_shape, softness, name='soft_false'))
+
+
+def make_labels_contiguous(labels: tf.Tensor) -> tf.Tensor:
+  """Maps the labels to [0..N].
+
+  Args:
+    labels: [1, z, y, x, 1] int64 tensor of labels
+
+  Returns:
+    labels mapped to the range [0..N] if N distinct non-zero values are
+    present in the input tensor
+  """
+  ret = tf.py_func(
+      label_utils.make_contiguous,
+      inp=[labels],
+      Tout=tf.int64,
+      name='make_labels_contiguous',
+  )
+  ret.set_shape(labels.shape)
+  return ret
+
+
+def apply_augmentation(
+    data: dict[str, Any],
+    section_augment: bool,
+    section_augmentation_args: Optional[dict[str, Any]],
+    permute_and_reflect_augment: bool,
+    permutable_axes: list[int],
+    reflectable_axes: list[int],
+    rotation_augmentation: Optional[str],
+    voxel_size: Optional[tuple[float, float, float]],
+) -> dict[str, Any]:
+  """Applies augmentations to a subvolume of data and corresponding labels.
+
+  Args:
+    data: dict containing at least 'labels' and 'patches' tensors
+    section_augment: whether to apply section augmentations
+    section_augmentation_args: kwargs for
+      augmentation.apply_section_augmentations
+    permute_and_reflect_augment: whether to apply permutation/reflection
+    permutable_axes: list of axes to permute
+    reflectable_axes: list of axes to reflect
+    rotation_augmentation: type of rotation augmentation to perform ('2d', '3d')
+    voxel_size: xyz voxel size of the input data (only needed when applying
+      rotation augmentation)
+
+  Returns:
+    'data' dict with 'labels' and 'patches' entries updated according to the
+    chosen augmentations
+  """
+  labels = data['labels']
+  patches = data['patches']
+
+  # Apply section-wise augmentations.
+  if section_augment:
+    final_data_zyx = patches.shape_as_list()[1:4]
+    final_label_zyx = labels.shape_as_list()[1:4]
+    patches, labels, _ = augmentation.apply_section_augmentations(
+        patches,
+        labels,
+        labels,
+        final_data_zyx,
+        final_label_zyx,
+        final_label_zyx,
+        **section_augmentation_args,
+    )
+
+  # Apply basic augmentations.
+  if permute_and_reflect_augment:
+    transform_axes = augmentation.PermuteAndReflect(
+        rank=5,
+        permutable_axes=permutable_axes,
+        reflectable_axes=reflectable_axes,
+    )
+    labels = transform_axes(labels)
+    patches = transform_axes(patches)
+
+  rot_mtx = None
+  if rotation_augmentation == '2d':
+    rot_mtx = augmentation.random_2d_rotation_matrix()
+  elif rotation_augmentation == '3d':
+    rot_mtx = augmentation.random_3d_rotation_matrix()
+
+  if rot_mtx is not None:
+    if labels.dtype == tf.int64:
+      labels = tf.cond(
+          tf.reduce_any(labels > np.iinfo(np.int32).max),  #
+          lambda: make_labels_contiguous(labels),  #
+          lambda: labels,
+      )
+      labels = tf.cast(labels, tf.int32)
+
+    assert voxel_size is not None
+    patches = augmentation.apply_rotation(patches, rot_mtx, voxel_size)
+    if labels.shape.as_list() != [1, 1, 1, 1, 1]:
+      labels = augmentation.apply_rotation(labels, rot_mtx, voxel_size)
+
+  data['labels'] = labels
+  data['patches'] = patches
+  return data
+
+
+def interleave(datasets: Sequence[tf.data.Dataset], repeat=True):
+  """Interleaves two or more datasets together, one element at a time.
+
+  Interleaves two independently generated datasets together, contrary to
+  Dataset.interleave, which interleaves new Datasets generated from each
+  input item.
+
+  Args:
+    datasets: Sequence of datasets to interleave.
+    repeat: repeat the interleaved sequence.
+
+  Returns:
+    tf.data.Dataset with interleaved results.
+  """
+  choice_dataset = tf.data.Dataset.range(len(datasets))
+  if repeat:
+    choice_dataset = choice_dataset.repeat()
+  return tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+
+
+def sample(
+    datasets: Sequence[tf.data.Dataset],
+    repeat=True,
+    weights: Optional[Sequence[float]] = None,
+):
+  """Weighted sample of two or more datasets.
+
+  Args:
+    datasets: Sequence of datasets to sample.
+    repeat: repeat the sampled sequence.
+    weights: relative weight of each respective dataset.
+
+  Returns:
+    tf.data.Dataset with sampled results.
+  """
+  if weights is None:
+    weights = [1.0] * len(datasets)
+  sampled_dataset = tf.data.experimental.sample_from_datasets(datasets, weights)
+  if repeat:
+    sampled_dataset = sampled_dataset.repeat()
+  return sampled_dataset
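
make_labels_contiguous exists to make the int64-to-int32 cast in apply_augmentation's rotation branch safe: relabeling is only triggered when some label exceeds the int32 range. A toy illustration of the intended effect; the exact value ordering produced by label_utils.make_contiguous is an assumption here, only the [0..N] range is promised by the docstring:

import numpy as np

labels = np.array([0, 5_000_000_000, 5_000_000_000, 7], dtype=np.int64)
# Expected: distinct values densely renumbered into [0..N], e.g.
# [0, 1, 1, 2] (or another permutation) -- all values then fit in int32.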
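
A usage sketch for apply_augmentation; the 5D [batch, z, y, x, channel] shapes, the axis indices, and the voxel size are made-up illustration values:

data = {
    'patches': tf.zeros([1, 32, 64, 64, 1], dtype=tf.float32),
    'labels': tf.zeros([1, 32, 64, 64, 1], dtype=tf.int64),
}
data = apply_augmentation(
    data,
    section_augment=False,
    section_augmentation_args=None,
    permute_and_reflect_augment=True,
    permutable_axes=[2, 3],      # assumed to index y/x of the 5D layout
    reflectable_axes=[1, 2, 3],  # assumed to index z/y/x
    rotation_augmentation='2d',
    voxel_size=(8.0, 8.0, 30.0),  # hypothetical xyz size in nm
)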
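
Finally, a small sketch contrasting the two dataset-combination helpers; the toy datasets stand in for coordinate streams, already repeating as they typically are in this pipeline:

a = tf.data.Dataset.from_tensor_slices([0, 0, 0, 0]).repeat()
b = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1]).repeat()

round_robin = interleave([a, b])           # deterministic: 0, 1, 0, 1, ...
weighted = sample([a, b], weights=[3, 1])  # stochastic: ~75% zeros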
