1313# limitations under the License.
1414import logging
1515import textwrap
16- from typing import Dict , List , Optional , Tuple , Union
16+ from typing import ClassVar , Dict , List , Optional , Tuple , Union
1717
1818from pydantic import BaseModel , Field
1919
3030]
3131
3232_LOGGER = logging .getLogger (__name__ )
33+ PrintOrderType = ClassVar [List [str ]]
3334
3435
3536class PropertyBaseModel (BaseModel ):
@@ -104,11 +105,12 @@ class NodeIO(BaseModel):
104105
105106 name : str = Field (description = "Name of the input/output in onnx model graph" )
106107 shape : Optional [List [Union [None , int ]]] = Field (
108+ None ,
107109 description = "Shape of the input/output in onnx model graph (assuming a "
108- "batch size of 1)"
110+ "batch size of 1)" ,
109111 )
110112 dtype : Optional [str ] = Field (
111- description = "Data type of the values from the input/output"
113+ None , description = "Data type of the values from the input/output"
112114 )
113115
114116
@@ -220,9 +222,9 @@ class ParameterComponent(BaseModel):
220222 """
221223
222224 alias : str = Field (description = "The type of parameter (weight, bias)" )
223- name : Optional [str ] = Field (description = "The name of the parameter" )
225+ name : Optional [str ] = Field (None , description = "The name of the parameter" )
224226 shape : Optional [List [Union [None , int ]]] = Field (
225- description = "The shape of the parameter"
227+ None , description = "The shape of the parameter"
226228 )
227229 parameter_summary : ParameterSummary = Field (
228230 description = "A summary of the parameter"
@@ -235,7 +237,7 @@ class Entry(BaseModel):
235237 A BaseModel with subtraction and pretty_print support
236238 """
237239
238- _print_order : List [ str ] = []
240+ _print_order : PrintOrderType = []
239241
240242 def __sub__ (self , other ):
241243 """
@@ -306,7 +308,7 @@ class BaseEntry(Entry):
306308 sparsity : float
307309 quantized : float
308310
309- _print_order = ["sparsity" , "quantized" ]
311+ _print_order : PrintOrderType = ["sparsity" , "quantized" ]
310312
311313
312314class NamedEntry (BaseEntry ):
@@ -318,7 +320,7 @@ class NamedEntry(BaseEntry):
318320 total : float
319321 size : int
320322
321- _print_order = ["name" , "total" , "size" ] + BaseEntry ._print_order
323+ _print_order : PrintOrderType = ["name" , "total" , "size" ] + BaseEntry ._print_order
322324
323325
324326class TypedEntry (BaseEntry ):
@@ -329,7 +331,7 @@ class TypedEntry(BaseEntry):
329331 type : str
330332 size : int
331333
332- _print_order = ["type" , "size" ] + BaseEntry ._print_order
334+ _print_order : PrintOrderType = ["type" , "size" ] + BaseEntry ._print_order
333335
334336
335337class ModelEntry (BaseEntry ):
@@ -338,7 +340,7 @@ class ModelEntry(BaseEntry):
338340 """
339341
340342 model : str
341- _print_order = ["model" ] + BaseEntry ._print_order
343+ _print_order : PrintOrderType = ["model" ] + BaseEntry ._print_order
342344
343345
344346class SizedModelEntry (ModelEntry ):
@@ -347,8 +349,8 @@ class SizedModelEntry(ModelEntry):
347349 """
348350
349351 count : int
350- size : int
351- _print_order = ModelEntry ._print_order + ["count" , "size" ]
352+ size : Union [ int , float ]
353+ _print_order : PrintOrderType = ModelEntry ._print_order + ["count" , "size" ]
352354
353355
354356class PerformanceEntry (BaseEntry ):
@@ -361,7 +363,7 @@ class PerformanceEntry(BaseEntry):
361363 throughput : float
362364 supported_graph : float
363365
364- _print_order = [
366+ _print_order : PrintOrderType = [
365367 "model" ,
366368 "latency" ,
367369 "throughput" ,
@@ -377,7 +379,7 @@ class NodeTimingEntry(Entry):
377379 node_name : str
378380 avg_runtime : float
379381
380- _print_order = [
382+ _print_order : PrintOrderType = [
381383 "node_name" ,
382384 "avg_runtime" ,
383385 ] + Entry ._print_order
0 commit comments