1414# ==============================================================================
1515"""Tensorflow Python ops and utilities for generating network inputs."""
1616
17+ import random
1718import re
19+ from typing import Any , Callable , Optional , Sequence
1820
21+ from absl import logging
1922from connectomics .common import bounding_box
23+ from connectomics .common import box_generator
24+ from connectomics .segmentation import labels as label_utils
25+ from ffn .training import augmentation
2026import numpy as np
2127import tensorflow .compat .v1 as tf
2228from tensorflow .io import gfile
@@ -28,47 +34,203 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
2834 Args:
2935 coordinates_file_pattern: File pattern for TFRecords of
3036 input examples of the form of a glob
31- pattern or path@shards.
37+ pattern or path@shards
38+ or Comma-separated file patterns.
3239 shuffle: Whether to shuffle the coordinate file list. Note that the expanded
3340 coordinates_file_pattern is not guaranteed to be sorted
3441 alphabetically.
3542
3643 Returns:
3744 Tensorflow queue with coordinate filenames
3845 """
39- m = re .search (r'@(\d{1,})' , coordinates_file_pattern )
40- if m :
41- num_shards = int (m .group (1 ))
42- coord_file_list = [
43- re .sub (r'@(\d{1,})' , '-%.5d-of-%.5d' % (i , num_shards ),
44- coordinates_file_pattern )
45- for i in range (num_shards )]
46- else :
47- coord_file_list = gfile .glob (coordinates_file_pattern )
46+ coord_file_list = []
47+ for pattern in coordinates_file_pattern .split (',' ):
48+ m = re .search (r'@(\d{1,})' , pattern )
49+ if m :
50+ num_shards = int (m .group (1 ))
51+ coord_file_list .extend ([
52+ re .sub (
53+ r'@(\d{1,})' ,
54+ '-%.5d-of-%.5d' % (i , num_shards ),
55+ pattern ,
56+ )
57+ for i in range (num_shards )
58+ ])
59+ else :
60+ coord_file_list .extend (gfile .glob (pattern ))
4861 return tf .train .string_input_producer (coord_file_list , shuffle = shuffle )
4962
5063
def load_patch_coordinates_from_filename_queue(filename_queue,
                                               file_format='tfrecords'):
  """Loads coordinates and volume names from filename queue.

  Args:
    filename_queue: Tensorflow queue created from create_filename_queue()
    file_format: String indicating the format of the files in the queue.
      Only 'tfrecords' (GZIP-compressed TFRecords of tf.Example protos) is
      handled by this implementation. Defaults to 'tfrecords'.

  Returns:
    Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.

  Raises:
    ValueError: if `file_format` is anything other than 'tfrecords'.
  """
  # Reject unknown formats up front, before any reader ops are created.
  if file_format != 'tfrecords':
    raise ValueError(f'Unsupported file format: {file_format}.')

  record_options = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  # The reader also returns the record key, which is not needed here.
  _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
  examples = tf.parse_single_example(protos, features=dict(
      center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
      label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
  ))
  coord = examples['center']
  volname = examples['label_volume_name']
  return coord, volname
7090
7191
def sample_patch_coordinates(
    bboxes: Sequence[Sequence[bounding_box.BoundingBox]],
    volinfo_map_string: str,
    name='sample_patch_coordinates',
    rng_seed: Optional[int] = None,
) -> tf.data.Dataset:
  """Samples a coordinate uniformly at random from specified bboxes.

  Every dataset element draws a flat voxel index in [0, total voxels) and
  resolves it to a 1x1x1 box through a MultiBoxGenerator built over all
  bounding boxes; the start of that box is the sampled coordinate.

  Args:
    bboxes: sequence of sequences for bounding boxes (one seq. per volume)
    volinfo_map_string: comma delimited string mapping volname:volinfo_path,
      where volinfo_path is a gfile with text_format VolumeInfo proto for the
      volume from which patches should be extracted.
    name: passed to `name_scope`
    rng_seed: Random number generator seed allowing to make the dataset
      deterministic.

  Returns:
    tf.data.Dataset whose elements are dicts with the entries:
      'coord': [1, 3] int64 xyz coord tensor
      'volname': [1] string tensor with the volume label

  Raises:
    ValueError: if len(bboxes) != number of entries in volinfo_map_string, if
      a bbox has a negative extent, or if the bboxes contain no voxels at all.
  """
  volinfo_pairs = volinfo_map_string.split(',')
  if len(bboxes) != len(volinfo_pairs):
    raise ValueError(
        'Numbers of bounding boxes and volume paths do not match.'
    )

  # `volumes[i]` is the index of the volume that `flat_boxes[i]` belongs to.
  volumes, flat_boxes = [], []
  total_voxels = 0
  for vol_id, volume_boxes in enumerate(bboxes):
    for b in volume_boxes:
      # Inspect every extent: checking only the sign of the product would let
      # a box with an even number of negative extents through.
      if np.any(np.asarray(b.size) < 0):
        raise ValueError('Volume %d, bbox %r is too small.' % (vol_id, b))
      total_voxels += np.prod(b.size)
      flat_boxes.append(b)
      volumes.append(vol_id)

  # The random index below is drawn with maxval=total_voxels, which needs a
  # non-empty range; fail here with a clear message instead of at run time.
  if total_voxels <= 0:
    raise ValueError('Bounding boxes contain no voxels to sample from.')

  calc = box_generator.MultiBoxGenerator(
      flat_boxes, box_size=(1, 1, 1), box_overlap=(0, 0, 0)
  )
  # Volume names are the part of each 'volname:volinfo_path' pair before ':'.
  volnames = [v.split(':')[0] for v in volinfo_pairs]

  def _sample_volinfo_and_bbox(idx):
    """Maps a flat voxel index (shape [1]) to (coordinate, volume name)."""
    idx = idx[0]
    vol_idx = volumes[calc.index_to_generator_index(idx)[0]]
    _, coord_bbox = calc.generate(idx)
    assert coord_bbox is not None
    logging.info(
        'Sampled location %r from volume %s',
        coord_bbox.start,
        volnames[vol_idx],
    )
    coord = np.array([coord_bbox.start]).astype(np.int64)
    return coord, volnames[vol_idx]

  def _sample(seed):
    """Builds one sample from a shape-[2] seed for the stateless RNG."""
    with tf.name_scope(name=name):
      coord, label = tf.py_func(
          _sample_volinfo_and_bbox,
          [
              tf.random.stateless_uniform(
                  [1],
                  seed,
                  maxval=total_voxels,
                  dtype=tf.int64,
                  name='rand',
              )
          ],
          [tf.int64, tf.string],
          name='sample_volinfo_and_bbox',
          stateful=False,
      )
      # py_func outputs have unknown static shape; restore it explicitly.
      label.set_shape([])
      coord.set_shape([1, 3])
      return {'coord': coord, 'volname': tf.reshape(label, [1])}

  # Pairs of random numbers from the source dataset act as per-element seeds.
  # This is faster than calling _sample_volinfo_and_bbox via .from_generator.
  return tf.data.Dataset.random(seed=rng_seed).batch(2).map(_sample)
176+
177+
def get_vol_map(volinfo_paths: Sequence[str]):
  """Builds a comma-delimited 'volname:volinfo_path' mapping string.

  Volumes are named 'vol0', 'vol1', ... following the order of the input.

  Args:
    volinfo_paths: sequence of volinfo paths, one per volume

  Returns:
    string of the form 'vol0:<path0>,vol1:<path1>,...'; empty if no paths are
    given
  """
  entries = []
  for index, path in enumerate(volinfo_paths):
    entries.append('vol%d:%s' % (index, path))
  return ','.join(entries)
182+
183+
def parse_tf_coords(x):
  """Parses a serialized tf.Example describing a sampling coordinate.

  Args:
    x: serialized tf.Example proto

  Returns:
    dict of tensors with the entries:
      'coord': [1, 3] int64
      'volname': [1] string
      'label': [1] int64
      'segment_id': [1] int64, 0 when absent from the example
      'radius': [1] float32, 0 when absent from the example
  """
  # Optional features fall back to zero when missing from the record.
  default_segment_id = tf.constant([0], dtype=tf.int64)
  default_radius = tf.constant([0], dtype=tf.float32)

  feature_spec = {
      'coord': tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
      'volname': tf.FixedLenFeature(shape=[1], dtype=tf.string),
      'label': tf.FixedLenFeature(shape=[1], dtype=tf.int64),
      'segment_id': tf.FixedLenFeature(
          shape=[1], dtype=tf.int64, default_value=default_segment_id
      ),
      'radius': tf.FixedLenFeature(
          shape=[1], dtype=tf.float32, default_value=default_radius
      ),
  }
  return tf.io.parse_single_example(x, features=feature_spec)
203+
204+
def load_coordinates_from_tfex(
    coord_pattern: str,
    shuffle: bool = True,
    shuffle_size: Optional[int] = 4096,
    shuffle_seed: Optional[int] = None,
    parse_fn: Callable[[Any], dict[str, Any]] = parse_tf_coords,
    reshuffle_each_iteration: bool = True,
) -> tf.data.Dataset:
  """Loads coordinates from a RecordIO of tf.Example protos.

  Args:
    coord_pattern: glob pattern matching the record files to read
    shuffle: whether to shuffle both the order of the input files and the
      parsed records
    shuffle_size: size of the record shuffle buffer (only used if `shuffle`)
    shuffle_seed: seed for the file-order and the record shuffle; None leaves
      both unseeded
    parse_fn: function mapping a serialized record to a dict of tensors
    reshuffle_each_iteration: passed to `Dataset.shuffle`

  Returns:
    infinitely repeating tf.data.Dataset of parsed records
  """
  # Sort first so that the file order does not depend on what glob returns.
  coord_paths = sorted(gfile.glob(coord_pattern))
  if shuffle:
    # Compare against None explicitly: 0 is a valid seed and must also give a
    # reproducible file order.
    if shuffle_seed is not None:
      random.Random(shuffle_seed).shuffle(coord_paths)
    else:
      random.shuffle(coord_paths)
  logging.info('Loading data from: %r', coord_paths)
  # NOTE(review): RecordIODataset does not appear to be part of the public
  # TensorFlow API -- confirm it is available in the target environment.
  ds = tf.data.RecordIODataset(tf.constant(coord_paths, dtype=tf.string))

  ds = ds.map(parse_fn, deterministic=True)
  if shuffle:
    ds = ds.shuffle(
        shuffle_size,
        seed=shuffle_seed,
        reshuffle_each_iteration=reshuffle_each_iteration,
    )

  return ds.repeat()
232+
233+
72234def load_patch_coordinates (coordinates_file_pattern ,
73235 shuffle = True ,
74236 scope = 'load_patch_coordinates' ):
@@ -187,6 +349,7 @@ def get_offset_scale(volname,
187349 """
188350
189351 def _get_offset_scale (volname ):
352+ volname = volname .decode ('utf-8' )
190353 if volname in offset_scale_map :
191354 offset , scale = offset_scale_map [volname ]
192355 else :
@@ -356,3 +519,146 @@ def soften_labels(bool_labels, softness=0.05, scope='soften_labels'):
356519 return tf .where (bool_labels ,
357520 tf .fill (label_shape , 1.0 - softness , name = 'soft_true' ),
358521 tf .fill (label_shape , softness , name = 'soft_false' ))
522+
523+
def make_labels_contiguous(labels: tf.Tensor) -> tf.Tensor:
  """Maps the labels to [0..N].

  The relabeling itself is done by `label_utils.make_contiguous`, executed
  through a `py_func` op.

  Args:
    labels: [1, z, y, x, 1] int64 tensor of labels

  Returns:
    int64 tensor with the same static shape as `labels`, with labels mapped to
    the range [0..N] if N distinct non-zero values are present in the input
    tensor
  """
  ret = tf.py_func(
      label_utils.make_contiguous,
      inp=[labels],
      Tout=tf.int64,
      name='make_labels_contiguous',
  )
  # py_func drops static shape information; the relabeling keeps the shape.
  ret.set_shape(labels.shape)
  return ret
542+
543+
def apply_augmentation(
    data: dict[str, Any],
    section_augment: bool,
    section_augmentation_args: Optional[dict[str, Any]],
    permute_and_reflect_augment: bool,
    permutable_axes: list[int],
    reflectable_axes: list[int],
    rotation_augmentation: Optional[str],
    voxel_size: Optional[tuple[float, float, float]],
) -> dict[str, Any]:
  """Applies augmentations to a subvolume of data and corresponding labels.

  Augmentations are applied in the order: section-wise, permute/reflect,
  rotation. The input dict is updated in place and also returned.

  Args:
    data: dict containing at least 'labels' and 'patches' tensors
    section_augment: whether to apply section augmentations
    section_augmentation_args: kwargs for
      augmentation.apply_section_augmentations; None is treated as no extra
      kwargs
    permute_and_reflect_augment: whether to apply permutation/reflection
    permutable_axes: list of axes to permute
    reflectable_axes: list of axes to reflect
    rotation_augmentation: type of rotation augmentation to perform ('2d',
      '3d'); any other value disables rotation
    voxel_size: xyz voxel size of the input data (only needed when applying
      rotation augmentation)

  Returns:
    'data' dict with 'labels' and 'patches' entries updated according to the
    chosen augmentations

  Raises:
    ValueError: if rotation augmentation is requested without `voxel_size`.
  """
  labels = data['labels']
  patches = data['patches']

  # Apply section-wise augmentations.
  if section_augment:
    final_data_zyx = patches.shape.as_list()[1:4]
    final_label_zyx = labels.shape.as_list()[1:4]
    # The labels are passed for both the second and the third input; the
    # third output is not used.
    patches, labels, _ = augmentation.apply_section_augmentations(
        patches,
        labels,
        labels,
        final_data_zyx,
        final_label_zyx,
        final_label_zyx,
        **(section_augmentation_args or {}),
    )

  # Apply basic augmentations.
  if permute_and_reflect_augment:
    # A single transform object is applied to both tensors.
    transform_axes = augmentation.PermuteAndReflect(
        rank=5,
        permutable_axes=permutable_axes,
        reflectable_axes=reflectable_axes,
    )
    labels = transform_axes(labels)
    patches = transform_axes(patches)

  rot_mtx = None
  if rotation_augmentation == '2d':
    rot_mtx = augmentation.random_2d_rotation_matrix()
  elif rotation_augmentation == '3d':
    rot_mtx = augmentation.random_3d_rotation_matrix()

  if rot_mtx is not None:
    if voxel_size is None:
      raise ValueError('voxel_size is required for rotation augmentation.')

    if labels.dtype == tf.int64:
      # Labels are narrowed to int32 below; relabel first if any value would
      # not fit into that range.
      labels = tf.cond(
          tf.reduce_any(labels > np.iinfo(np.int32).max),
          lambda: make_labels_contiguous(labels),
          lambda: labels,
      )
      labels = tf.cast(labels, tf.int32)

    patches = augmentation.apply_rotation(patches, rot_mtx, voxel_size)
    # A single-voxel label tensor is left unrotated.
    if labels.shape.as_list() != [1, 1, 1, 1, 1]:
      labels = augmentation.apply_rotation(labels, rot_mtx, voxel_size)

  data['labels'] = labels
  data['patches'] = patches
  return data
622+
623+
def interleave(datasets: Sequence[tf.data.Dataset], repeat=True):
  """Interleaves two or more datasets, taking one element from each in turn.

  Unlike Dataset.interleave, which interleaves new Datasets generated from
  each input item, this combines datasets that were built independently.

  Args:
    datasets: Sequence of datasets to interleave.
    repeat: repeat the interleaved sequence.

  Returns:
    tf.data.Dataset with interleaved results.
  """
  # The selector yields 0, 1, ..., len(datasets) - 1: the index of the dataset
  # to draw the next element from.
  selector = tf.data.Dataset.range(len(datasets))
  selector = selector.repeat() if repeat else selector
  return tf.data.experimental.choose_from_datasets(datasets, selector)
642+
643+
def sample(
    datasets: Sequence[tf.data.Dataset],
    repeat=True,
    weights: Optional[Sequence[float]] = None,
):
  """Draws elements from two or more datasets by weighted random sampling.

  Args:
    datasets: Sequence of datasets to sample.
    repeat: repeat the sampled sequence.
    weights: relative weight of each respective dataset; all datasets are
      weighted equally when None.

  Returns:
    tf.data.Dataset with sampled results.
  """
  effective_weights = [1.0] * len(datasets) if weights is None else weights
  mixed = tf.data.experimental.sample_from_datasets(
      datasets, effective_weights
  )
  return mixed.repeat() if repeat else mixed
0 commit comments