Skip to content

Commit ec4d407

Browse files
rchen152mn-robot
authored andcommitted
Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376873484
1 parent f41c85c commit ec4d407

File tree

5 files changed

+98
-96
lines changed

5 files changed

+98
-96
lines changed

morph_net/network_regularizers/activation_regularizer.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# [internal] enable type annotations
66
from __future__ import print_function
77

8+
from typing import Optional
9+
810
from morph_net.framework import batch_norm_source_op_handler
911
from morph_net.framework import conv2d_transpose_source_op_handler
1012
from morph_net.framework import conv_source_op_handler
@@ -22,15 +24,15 @@
2224
class GammaActivationRegularizer(generic_regularizers.NetworkRegularizer):
2325
"""A NetworkRegularizer that targets activation count using Gamma L1."""
2426

25-
def __init__(
26-
self,
27-
output_boundary: List[tf.Operation],
28-
gamma_threshold,
29-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
30-
decorator_parameters=None,
31-
input_boundary: List[tf.Operation] = None,
32-
force_group=None,
33-
regularizer_blacklist=None):
27+
def __init__(self,
28+
output_boundary: List[tf.Operation],
29+
gamma_threshold,
30+
regularizer_decorator: Optional[Type[
31+
generic_regularizers.OpRegularizer]] = None,
32+
decorator_parameters=None,
33+
input_boundary: Optional[List[tf.Operation]] = None,
34+
force_group=None,
35+
regularizer_blacklist=None):
3436
"""Creates a GammaActivationRegularizer object.
3537
3638
Args:
@@ -95,16 +97,16 @@ def cost_name(self):
9597
class GroupLassoActivationRegularizer(generic_regularizers.NetworkRegularizer):
9698
"""A NetworkRegularizer that targets activation count using L1 group lasso."""
9799

98-
def __init__(
99-
self,
100-
output_boundary: List[tf.Operation],
101-
threshold,
102-
l1_fraction=0,
103-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
104-
decorator_parameters=None,
105-
input_boundary: List[tf.Operation] = None,
106-
force_group=None,
107-
regularizer_blacklist=None):
100+
def __init__(self,
101+
output_boundary: List[tf.Operation],
102+
threshold,
103+
l1_fraction=0,
104+
regularizer_decorator: Optional[Type[
105+
generic_regularizers.OpRegularizer]] = None,
106+
decorator_parameters=None,
107+
input_boundary: Optional[List[tf.Operation]] = None,
108+
force_group=None,
109+
regularizer_blacklist=None):
108110
"""Creates a GroupLassoActivationRegularizer object.
109111
110112
Args:

morph_net/network_regularizers/flop_regularizer.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import division
55
# [internal] enable type annotations
66
from __future__ import print_function
7-
from typing import Type, List
7+
from typing import List, Optional, Type
88

99
from morph_net.framework import batch_norm_source_op_handler
1010
from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -40,15 +40,15 @@ def cost_name(self):
4040
class GammaFlopsRegularizer(generic_regularizers.NetworkRegularizer):
4141
"""A NetworkRegularizer that targets FLOPs using Gamma L1 as OpRegularizer."""
4242

43-
def __init__(
44-
self,
45-
output_boundary: List[tf.Operation],
46-
gamma_threshold,
47-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
48-
decorator_parameters=None,
49-
input_boundary: List[tf.Operation] = None,
50-
force_group=None,
51-
regularizer_blacklist=None):
43+
def __init__(self,
44+
output_boundary: List[tf.Operation],
45+
gamma_threshold,
46+
regularizer_decorator: Optional[Type[
47+
generic_regularizers.OpRegularizer]] = None,
48+
decorator_parameters=None,
49+
input_boundary: Optional[List[tf.Operation]] = None,
50+
force_group=None,
51+
regularizer_blacklist=None):
5252
"""Creates a GammaFlopsRegularizer object.
5353
5454
Args:
@@ -113,16 +113,16 @@ def cost_name(self):
113113
class GroupLassoFlopsRegularizer(generic_regularizers.NetworkRegularizer):
114114
"""A NetworkRegularizer that targets FLOPs using L1 group lasso."""
115115

116-
def __init__(
117-
self,
118-
output_boundary: List[tf.Operation],
119-
threshold,
120-
l1_fraction=0,
121-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
122-
decorator_parameters=None,
123-
input_boundary: List[tf.Operation] = None,
124-
force_group=None,
125-
regularizer_blacklist=None):
116+
def __init__(self,
117+
output_boundary: List[tf.Operation],
118+
threshold,
119+
l1_fraction=0,
120+
regularizer_decorator: Optional[Type[
121+
generic_regularizers.OpRegularizer]] = None,
122+
decorator_parameters=None,
123+
input_boundary: Optional[List[tf.Operation]] = None,
124+
force_group=None,
125+
regularizer_blacklist=None):
126126
"""Creates a GroupLassoFlopsRegularizer object.
127127
128128
Args:

morph_net/network_regularizers/latency_regularizer.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""A NetworkRegularizer that targets inference latency."""
22

3-
from typing import Type, List
3+
from typing import List, Optional, Type
44

55
from morph_net.framework import batch_norm_source_op_handler
66
from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -52,19 +52,19 @@ class LogisticSigmoidLatencyRegularizer(
5252
regularized. See op_regularizer_manager for more detail.
5353
"""
5454

55-
def __init__(
56-
self,
57-
output_boundary: List[tf.Operation],
58-
hardware,
59-
batch_size=1,
60-
regularize_on_mask=True,
61-
alive_threshold=0.1,
62-
mask_as_alive_vector=True,
63-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
64-
decorator_parameters=None,
65-
input_boundary: List[tf.Operation] = None,
66-
force_group=None,
67-
regularizer_blacklist=None):
55+
def __init__(self,
56+
output_boundary: List[tf.Operation],
57+
hardware,
58+
batch_size=1,
59+
regularize_on_mask=True,
60+
alive_threshold=0.1,
61+
mask_as_alive_vector=True,
62+
regularizer_decorator: Optional[Type[
63+
generic_regularizers.OpRegularizer]] = None,
64+
decorator_parameters=None,
65+
input_boundary: Optional[List[tf.Operation]] = None,
66+
force_group=None,
67+
regularizer_blacklist=None):
6868

6969
self._hardware = hardware
7070
self._batch_size = batch_size
@@ -97,17 +97,17 @@ def cost_name(self):
9797
class GammaLatencyRegularizer(generic_regularizers.NetworkRegularizer):
9898
"""A NetworkRegularizer that targets latency using Gamma L1."""
9999

100-
def __init__(
101-
self,
102-
output_boundary: List[tf.Operation],
103-
gamma_threshold,
104-
hardware,
105-
batch_size=1,
106-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
107-
decorator_parameters=None,
108-
input_boundary: List[tf.Operation] = None,
109-
force_group=None,
110-
regularizer_blacklist=None) -> None:
100+
def __init__(self,
101+
output_boundary: List[tf.Operation],
102+
gamma_threshold,
103+
hardware,
104+
batch_size=1,
105+
regularizer_decorator: Optional[Type[
106+
generic_regularizers.OpRegularizer]] = None,
107+
decorator_parameters=None,
108+
input_boundary: Optional[List[tf.Operation]] = None,
109+
force_group=None,
110+
regularizer_blacklist=None) -> None:
111111
"""Creates a GammaLatencyRegularizer object.
112112
113113
Latency cost and regularization loss is calculated for a specified hardware

morph_net/network_regularizers/logistic_sigmoid_regularizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import print_function
77

88
import abc
9-
from typing import Type, List
9+
from typing import List, Optional, Type
1010

1111
from morph_net.framework import generic_regularizers
1212
from morph_net.framework import logistic_sigmoid_source_op_handler as ls_handler
@@ -23,17 +23,17 @@
2323
class LogisticSigmoidRegularizer(generic_regularizers.NetworkRegularizer):
2424
"""Base class for NetworkRegularizers that use probabilistic sampling."""
2525

26-
def __init__(
27-
self,
28-
output_boundary: List[tf.Operation],
29-
regularize_on_mask=True,
30-
alive_threshold=0.1,
31-
mask_as_alive_vector=True,
32-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
33-
decorator_parameters=None,
34-
input_boundary: List[tf.Operation] = None,
35-
force_group=None,
36-
regularizer_blacklist=None):
26+
def __init__(self,
27+
output_boundary: List[tf.Operation],
28+
regularize_on_mask=True,
29+
alive_threshold=0.1,
30+
mask_as_alive_vector=True,
31+
regularizer_decorator: Optional[Type[
32+
generic_regularizers.OpRegularizer]] = None,
33+
decorator_parameters=None,
34+
input_boundary: Optional[List[tf.Operation]] = None,
35+
force_group=None,
36+
regularizer_blacklist=None):
3737
"""Creates a LogisticSigmoidFlopsRegularizer object.
3838
3939
Args:

morph_net/network_regularizers/model_size_regularizer.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import division
55
# [internal] enable type annotations
66
from __future__ import print_function
7-
from typing import Text, Type, List
7+
from typing import List, Optional, Text, Type
88

99
from morph_net.framework import batch_norm_source_op_handler
1010
from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -40,15 +40,15 @@ def cost_name(self):
4040
class GammaModelSizeRegularizer(generic_regularizers.NetworkRegularizer):
4141
"""A NetworkRegularizer that targets model size using Gamma L1."""
4242

43-
def __init__(
44-
self,
45-
output_boundary: List[tf.Operation],
46-
gamma_threshold,
47-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
48-
decorator_parameters=None,
49-
input_boundary: List[tf.Operation] = None,
50-
force_group=None,
51-
regularizer_blacklist=None):
43+
def __init__(self,
44+
output_boundary: List[tf.Operation],
45+
gamma_threshold,
46+
regularizer_decorator: Optional[Type[
47+
generic_regularizers.OpRegularizer]] = None,
48+
decorator_parameters=None,
49+
input_boundary: Optional[List[tf.Operation]] = None,
50+
force_group=None,
51+
regularizer_blacklist=None):
5252
"""Creates a GammaModelSizeRegularizer object.
5353
5454
Args:
@@ -112,16 +112,16 @@ def cost_name(self):
112112
class GroupLassoModelSizeRegularizer(generic_regularizers.NetworkRegularizer):
113113
"""A NetworkRegularizer that targets model size using L1 group lasso."""
114114

115-
def __init__(
116-
self,
117-
output_boundary: List[tf.Operation],
118-
threshold,
119-
l1_fraction=0.0,
120-
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
121-
decorator_parameters=None,
122-
input_boundary: List[tf.Operation] = None,
123-
force_group: List[Text] = None,
124-
regularizer_blacklist: List[Text] = None):
115+
def __init__(self,
116+
output_boundary: List[tf.Operation],
117+
threshold,
118+
l1_fraction=0.0,
119+
regularizer_decorator: Optional[Type[
120+
generic_regularizers.OpRegularizer]] = None,
121+
decorator_parameters=None,
122+
input_boundary: Optional[List[tf.Operation]] = None,
123+
force_group: Optional[List[Text]] = None,
124+
regularizer_blacklist: Optional[List[Text]] = None):
125125
"""Creates a GroupLassoModelSizeRegularizer object.
126126
127127
Args:

0 commit comments

Comments
 (0)