Skip to content

Commit 5f60abd

Browse files
authored
Minor enhancements (#2026)
1 parent f0597ba commit 5f60abd

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

coremltools/converters/mil/input_types.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __repr__(self):
152152
return self.__str__()
153153

154154
def __str__(self):
155-
str_repr = 'ImageType[name={}, shape={}, scale={}, bias={}, ' +\
156-
'color_layout={}, channel_first={}]'
155+
str_repr = 'ImageType(name={}, shape={}, scale={}, bias={}, ' +\
156+
'color_layout={}, channel_first={})'
157157
return str_repr.format(self.name, self.shape, self.scale, self.bias,
158158
self.color_layout, self.channel_first)
159159

@@ -268,7 +268,7 @@ def __repr__(self):
268268
return self.__str__()
269269

270270
def __str__(self):
271-
return 'TensorType[name={}, shape={}, dtype={}]'.format(self.name,
271+
return 'TensorType(name={}, shape={}, dtype={})'.format(self.name,
272272
self.shape,
273273
self.dtype)
274274

@@ -408,6 +408,11 @@ def __init__(self, shape, default=None):
408408
def __str__(self):
409409
return str(self.shape)
410410

411+
412+
def __repr__(self):
413+
return self.__str__()
414+
415+
411416
@property
412417
def has_symbolic(self):
413418
return any(is_symbolic(s) for s in self.symbolic_shape)
@@ -436,7 +441,27 @@ def __init__(self, shapes, default=None):
436441
the metadata of the model file.
437442
438443
If None, then the first element in ``shapes`` is used.
444+
445+
Examples
446+
--------
447+
.. sourcecode:: python
448+
449+
sample_shape = ct.EnumeratedShapes(
450+
shapes=[
451+
(2, 4, 64, 64),
452+
(2, 4, 48, 48),
453+
(2, 4, 32, 32)
454+
],
455+
default=(2, 4, 64, 64)
456+
)
457+
458+
my_core_ml_model = ct.convert(
459+
my_model,
460+
inputs=[ct.TensorType(name="sample", shape=sample_shape)],
461+
)
439462
"""
463+
464+
# lazy import to avoid circular import
440465
from coremltools.converters.mil.mil import get_new_symbol
441466

442467
if not isinstance(shapes, (list, tuple)):
@@ -490,6 +515,14 @@ def __init__(self, shapes, default=None):
490515
self.default = default
491516

492517

518+
def __repr__(self):
519+
return self.__str__()
520+
521+
522+
def __str__(self):
523+
return "EnumeratedShapes(" + str(self.shapes) + ", default=" + str(self.default) + ")"
524+
525+
493526
def _get_shaping_class(shape):
494527
"""
495528
Returns a Shape class or EnumeratedShapes class for `shape`

coremltools/converters/mil/mil/ops/defs/iOS15/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def type_inference(self):
291291
x_shape = x.shape
292292
y_shape = y.shape
293293
assert len(x_shape) == len(y_shape), "inputs not of the same rank"
294-
assert x_shape[-1] == y_shape[-3], "input shapes incompatible"
294+
if not (is_symbolic(x_shape[-1]) or is_symbolic(y_shape[-3])):
295+
assert x_shape[-1] == y_shape[-3], f"input shapes incompatible: {x_shape[-1]} and {y_shape[-3]}"
295296
if x_shape[-2] != 1 and y_shape[-2] != 1:
296297
assert x_shape[-2] == y_shape[-2], "input shapes incompatible"
297298
if len(x_shape) == 4:

0 commit comments

Comments
 (0)