8585from dash .dependencies import Input , Output , State , ALL
8686from dash_core_components import Graph , Slider , Store , Interval
8787
88- from .utils import img_array_to_uri , get_thumbnail_size , shape3d_to_size2d
88+ from .utils import (
89+ discrete_colors ,
90+ img_array_to_uri ,
91+ get_thumbnail_size ,
92+ shape3d_to_size2d ,
93+ mask_to_coloured_slices ,
94+ )
8995
9096
91- # The default colors to use for indicators and overlays
92- discrete_colors = plotly .colors .qualitative .D3
93-
9497_assigned_scene_ids = {} # id(volume) -> str
9598
9699
@@ -170,14 +173,14 @@ def __init__(
170173 elif isinstance (clim , (tuple , list )) and len (clim ) == 2 :
171174 self ._initial_clim = float (clim [0 ]), float (clim [1 ])
172175 else :
173- raise ValueError ("The clim must be None or a 2-tuple of floats." )
176+ raise TypeError ("The clim must be None or a 2-tuple of floats." )
174177
175178 # Check and store thumbnail
176179 if not (isinstance (thumbnail , (int , bool ))):
177- raise ValueError ("thumbnail must be a boolean or an integer." )
178- if thumbnail is False :
180+ raise TypeError ("thumbnail must be a boolean or an integer." )
181+ if not thumbnail :
179182 self ._thumbnail_param = None
180- elif thumbnail is None or thumbnail is True :
183+ elif thumbnail is True :
181184 self ._thumbnail_param = 32 # default size
182185 else :
183186 thumbnail = int (thumbnail )
@@ -214,15 +217,23 @@ def __init__(
214217 "offset" : shape3d_to_size2d (origin , axis ),
215218 "stepsize" : shape3d_to_size2d (spacing , axis ),
216219 "color" : color ,
220+ "infoid" : np .random .randint (1 , 9999999 ),
217221 }
218222
223+ # Also store thumbnail size. The get_thumbnail_size() is a bit like
224+ # a simulation to get the low-res size.
225+ if self ._thumbnail_param is None :
226+ self ._slice_info ["thumbnail_size" ] = self ._slice_info ["size" ][:2 ]
227+ else :
228+ self ._slice_info ["thumbnail_size" ] = get_thumbnail_size (
229+ self ._slice_info ["size" ][:2 ], self ._thumbnail_param
230+ )
231+
219232 # Build the slicer
220233 self ._create_dash_components ()
221234 self ._create_server_callbacks ()
222235 self ._create_client_callbacks ()
223236
224- # Note(AK): we could make some stores public, but let's do this only when actual use-cases arise?
225-
226237 @property
227238 def scene_id (self ) -> str :
228239 """The id of the "virtual scene" for this slicer. Slicers that have
@@ -258,14 +269,16 @@ def slider(self):
258269 @property
259270 def stores (self ):
260271 """A list of `dcc.Store` objects that the slicer needs to work.
261- These must be added to the app layout.
272+ These must be added to the app layout. Note that public stores
273+ like `state` and `extra_traces` are also present in this list.
262274 """
263275 return self ._stores
264276
265277 @property
266278 def state (self ):
267279 """A `dcc.Store` representing the current state of the slicer (present
268- in slicer.stores). Its data is a dict with the fields:
280+ in slicer.stores). This store is intended for use as State or Input.
281+ Its data is a dict with the fields:
269282
270283 * "index": the integer slice index.
271284 * "index_changed": a bool indicating whether the index changed since last time.
@@ -283,26 +296,26 @@ def state(self):
283296
284297 @property
285298 def clim (self ):
286- """A `dcc.Store` representing the contrast limits as a 2-element tuple.
287- This value should probably not be changed too often (e.g. on slider drag)
288- because the thumbnail data is recreated on each change.
299+ """A `dcc.Store` to be used as Output, representing the contrast
300+ limits as a 2-element tuple. This value should probably not be
301+ changed too often (e.g. on slider drag) because the thumbnail
302+ data is recreated on each change.
289303 """
290304 return self ._clim
291305
292306 @property
293307 def extra_traces (self ):
294- """A `dcc.Store` that can be used as an output to define
295- additional traces to be shown in this slicer. The data must be
296- a list of dictionaries, with each dict representing a raw trace
297- object.
308+ """A `dcc.Store` to be used as an Output to define additional
309+ traces to be shown in this slicer. The data must be a list of
310+ dictionaries, with each dict representing a raw trace object.
298311 """
299312 return self ._extra_traces
300313
301314 @property
302315 def overlay_data (self ):
303- """A `dcc.Store` containing the overlay data. The form of this
304- data is considered an implementation detail; users are expected to use
305- `create_overlay_data` to create it.
316+ """A `dcc.Store` to be used an Output for the overlay data. The
317+ form of this data is considered an implementation detail; users
318+ are expected to use `create_overlay_data` to create it.
306319 """
307320 return self ._overlay_data
308321
@@ -312,71 +325,13 @@ def create_overlay_data(self, mask, color=None):
312325 The color can be a hex color or an rgb/rgba tuple. Alternatively,
313326 color can be a list of such colors, defining a colormap.
314327 """
315- # Check the mask
316328 if mask is None :
317329 return [None for index in range (self .nslices )] # A reset
318- elif not isinstance (mask , np .ndarray ):
319- raise TypeError ("Mask must be an ndarray or None." )
320- elif mask .dtype not in (np .bool , np .uint8 ):
321- raise ValueError (f"Mask must have bool or uint8 dtype, not { mask .dtype } ." )
322330 elif mask .shape != self ._volume .shape :
323331 raise ValueError (
324332 f"Overlay must has shape { mask .shape } , but expected { self ._volume .shape } "
325333 )
326- mask = mask .astype (np .uint8 , copy = False ) # need int to index
327-
328- # Create a colormap (list) from the given color(s)
329- if color is None :
330- colormap = discrete_colors [3 :]
331- elif isinstance (color , str ):
332- colormap = [color ]
333- elif isinstance (color , (tuple , list )) and all (
334- isinstance (x , (int , float )) for x in color
335- ):
336- colormap = [color ]
337- else :
338- colormap = list (color )
339-
340- # Normalize the colormap so each element is a 4-element tuple
341- for i in range (len (colormap )):
342- c = colormap [i ]
343- if isinstance (c , str ):
344- if c .startswith ("#" ):
345- c = plotly .colors .hex_to_rgb (c )
346- else :
347- raise ValueError (
348- "Named colors are not (yet) supported, hex colors are."
349- )
350- c = tuple (int (x ) for x in c )
351- if len (c ) == 3 :
352- c = c + (100 ,)
353- elif len (c ) != 4 :
354- raise ValueError ("Expected color tuples to be 3 or 4 elements." )
355- colormap [i ] = c
356-
357- # Insert zero stub color for where mask is zero
358- colormap .insert (0 , (0 , 0 , 0 , 0 ))
359-
360- # Produce slices (base64 png strings)
361- overlay_slices = []
362- for index in range (self .nslices ):
363- # Sample the slice
364- indices = [slice (None ), slice (None ), slice (None )]
365- indices [self ._axis ] = index
366- im = mask [tuple (indices )]
367- max_mask = im .max ()
368- if max_mask == 0 :
369- # If the mask is all zeros, we can simply not draw it
370- overlay_slices .append (None )
371- else :
372- # Turn into rgba
373- while len (colormap ) <= max_mask :
374- colormap .append (colormap [- 1 ])
375- colormap_arr = np .array (colormap )
376- rgba = colormap_arr [im ]
377- overlay_slices .append (img_array_to_uri (rgba ))
378-
379- return overlay_slices
334+ return mask_to_coloured_slices (mask , self ._axis , color )
380335
381336 def _subid (self , name , use_dict = False , ** kwargs ):
382337 """Given a name, get the full id including the context id prefix."""
@@ -412,15 +367,6 @@ def _create_dash_components(self):
412367 """Create the graph, slider, figure, etc."""
413368 info = self ._slice_info
414369
415- # Prep low-res slices. The get_thumbnail_size() is a bit like
416- # a simulation to get the low-res size.
417- if self ._thumbnail_param is None :
418- info ["thumbnail_size" ] = info ["size" ]
419- else :
420- info ["thumbnail_size" ] = get_thumbnail_size (
421- info ["size" ][:2 ], self ._thumbnail_param
422- )
423-
424370 # Create the figure object - can be accessed by user via slicer.graph.figure
425371 self ._fig = fig = plotly .graph_objects .Figure (data = [])
426372 fig .update_layout (
@@ -469,10 +415,10 @@ def _create_dash_components(self):
469415 # A dict of static info for this slicer
470416 self ._info = Store (id = self ._subid ("info" ), data = info )
471417
472- # A list of contrast limits
418+ # A tuple representing the contrast limits
473419 self ._clim = Store (id = self ._subid ("clim" ), data = self ._initial_clim )
474420
475- # A list of low-res slices , or the full-res data ( encoded as base64-png)
421+ # A list of thumbnails ( low-res, or the full-re, encoded as base64-png)
476422 self ._thumbs_data = Store (id = self ._subid ("thumbs" ), data = [])
477423
478424 # A list of mask slices (encoded as base64-png or null)
@@ -483,13 +429,13 @@ def _create_dash_components(self):
483429 id = self ._subid ("server-data" ), data = {"index" : - 1 , "slice" : None }
484430 )
485431
486- # Store image traces for the slicer.
432+ # Store image traces to show in the figure
487433 self ._img_traces = Store (id = self ._subid ("img-traces" ), data = [])
488434
489- # Store indicator traces for the slicer.
435+ # Store indicator traces to show in the figure
490436 self ._indicator_traces = Store (id = self ._subid ("indicator-traces" ), data = [])
491437
492- # Store user traces for the slider.
438+ # Store more ( user-defined) traces to show in the figure
493439 self ._extra_traces = Store (id = self ._subid ("extra-traces" ), data = [])
494440
495441 # A timer to apply a rate-limit between slider.value and index.data
@@ -554,12 +500,17 @@ def _create_client_callbacks(self):
554500 # \ server_data (a new slice)
555501 # \ \
556502 # \ --> image_traces
557- # ----------------------- / \
558- # -----> figure
503+ # ------------------------/ \
504+ # \
505+ # state (external) --> indicator_traces -- -----> figure
559506 # /
560- # indicator_traces
561- # /
562- # state (external)
507+ # extra_traces
508+ #
509+ # This figure is incomplete, for the sake of keeping it
510+ # relatively simple. E.g. the thumbnail data is also an input
511+ # for the callback that generates the image traces. And the
512+ # clim store is an input for the callbacks that produce
513+ # server_data and thumbnail data.
563514
564515 app = self ._app
565516
@@ -667,6 +618,7 @@ def _create_client_callbacks(self):
667618 Input (self ._graph .id , "relayoutData" ),
668619 Input (self ._timer .id , "n_intervals" ),
669620 ],
621+ prevent_initial_call = True ,
670622 )
671623
672624 # ----------------------------------------------------------------------
@@ -687,6 +639,10 @@ def _create_client_callbacks(self):
687639 if (!(private_state.timeout && now >= private_state.timeout)) {
688640 return dash_clientside.no_update;
689641 }
642+ // Give the plot time to settle the initial axis ranges
643+ if (n_intervals < 5) {
644+ return dash_clientside.no_update;
645+ }
690646
691647 // Disable the timer
692648 private_state.timeout = 0;
@@ -732,10 +688,11 @@ def _create_client_callbacks(self):
732688 axis: info.axis,
733689 color: info.color,
734690 };
735- if (index != private_state.last_index) {
691+ if (index != private_state.last_index || info.infoid != private_state.infoid ) {
736692 private_state.last_index = index;
737693 new_state.index_changed = true;
738694 }
695+ private_state.infoid = info.infoid; // infoid changes on hot reload
739696 return new_state;
740697 }
741698 """ .replace (
0 commit comments