@@ -152,8 +152,8 @@ def __repr__(self):
152152 return self .__str__ ()
153153
154154 def __str__ (self ):
155- str_repr = 'ImageType[ name={}, shape={}, scale={}, bias={}, ' + \
156- 'color_layout={}, channel_first={}] '
155+ str_repr = 'ImageType( name={}, shape={}, scale={}, bias={}, ' + \
156+ 'color_layout={}, channel_first={}) '
157157 return str_repr .format (self .name , self .shape , self .scale , self .bias ,
158158 self .color_layout , self .channel_first )
159159
@@ -268,7 +268,7 @@ def __repr__(self):
268268 return self .__str__ ()
269269
270270 def __str__ (self ):
271- return 'TensorType[ name={}, shape={}, dtype={}] ' .format (self .name ,
271+ return 'TensorType( name={}, shape={}, dtype={}) ' .format (self .name ,
272272 self .shape ,
273273 self .dtype )
274274
@@ -408,6 +408,11 @@ def __init__(self, shape, default=None):
408408 def __str__ (self ):
409409 return str (self .shape )
410410
411+
412+ def __repr__ (self ):
413+ return self .__str__ ()
414+
415+
411416 @property
412417 def has_symbolic (self ):
413418 return any (is_symbolic (s ) for s in self .symbolic_shape )
@@ -436,7 +441,27 @@ def __init__(self, shapes, default=None):
436441 the metadata of the model file.
437442
438443 If None, then the first element in ``shapes`` is used.
444+
445+ Examples
446+ --------
447+ .. sourcecode:: python
448+
449+ sample_shape = ct.EnumeratedShapes(
450+ shapes=[
451+ (2, 4, 64, 64),
452+ (2, 4, 48, 48),
453+ (2, 4, 32, 32)
454+ ],
455+ default=(2, 4, 64, 64)
456+ )
457+
458+ my_core_ml_model = ct.convert(
459+ my_model,
460+ inputs=[ct.TensorType(name="sample", shape=sample_shape)],
461+ )
439462 """
463+
464+ # lazy import to avoid circular import
440465 from coremltools .converters .mil .mil import get_new_symbol
441466
442467 if not isinstance (shapes , (list , tuple )):
@@ -490,6 +515,14 @@ def __init__(self, shapes, default=None):
490515 self .default = default
491516
492517
518+ def __repr__ (self ):
519+ return self .__str__ ()
520+
521+
522+ def __str__ (self ):
523+ return "EnumeratedShapes(" + str (self .shapes ) + ", default=" + str (self .default ) + ")"
524+
525+
493526def _get_shaping_class (shape ):
494527 """
495528 Returns a Shape class or EnumeratedShapes class for `shape`
0 commit comments