2323
2424import tensorflow as tf
2525from tensorflow_privacy .privacy .fast_gradient_clipping import gradient_clipping_utils
26+ from tensorflow_privacy .privacy .fast_gradient_clipping import layer_registry as lr
2627
2728
28- def get_registry_generator_fn (tape , layer_registry ):
29+ def get_registry_generator_fn (
30+ tape : tf .GradientTape , layer_registry : lr .LayerRegistry
31+ ):
2932 """Creates the generator function for `compute_gradient_norms()`."""
3033 if layer_registry is None :
3134 # Needed for backwards compatibility.
@@ -53,7 +56,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
5356 return registry_generator_fn
5457
5558
56- def compute_gradient_norms (input_model , x_batch , y_batch , layer_registry ):
59+ def compute_gradient_norms (
60+ input_model : tf .keras .Model ,
61+ x_batch : tf .Tensor ,
62+ y_batch : tf .Tensor ,
63+ layer_registry : lr .LayerRegistry ,
64+ ):
5765 """Computes the per-example loss gradient norms for given data.
5866
5967 Applies a variant of the approach given in
@@ -106,7 +114,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
106114 return tf .sqrt (tf .reduce_sum (sqr_norm_tsr , axis = 1 ))
107115
108116
109- def compute_clip_weights (l2_norm_clip , gradient_norms ):
117+ def compute_clip_weights (l2_norm_clip : float , gradient_norms : tf . Tensor ):
110118 """Computes the per-example loss/clip weights for clipping.
111119
112120 When the sum of the per-example losses is replaced a weighted sum, where
@@ -132,7 +140,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
132140
133141
134142def compute_pred_and_clipped_gradients (
135- input_model , x_batch , y_batch , l2_norm_clip , layer_registry
143+ input_model : tf .keras .Model ,
144+ x_batch : tf .Tensor ,
145+ y_batch : tf .Tensor ,
146+ l2_norm_clip : float ,
147+ layer_registry : lr .LayerRegistry ,
136148):
137149 """Computes the per-example predictions and per-example clipped loss gradient.
138150
0 commit comments