1+
12# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
23#
34# Licensed under the Apache License, Version 2.0 (the "License");
1920results can be visualized using tools like TensorBoard.
2021"""
2122
23+ from typing import Any
2224from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
2325from tensorflow import keras
2426import tensorflow .compat .v1 as tf
@@ -40,41 +42,50 @@ class ExampleParser(keras.layers.Layer):
4042
4143 def __init__ (self , input_feature_key ):
4244 self ._input_feature_key = input_feature_key
45+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
4346 super ().__init__ ()
4447
48+ def compute_output_shape (self , input_shape : Any ):
49+ return [1 , 1 ]
50+
4551 def call (self , serialized_examples ):
4652 def get_feature (serialized_example ):
4753 parsed_example = tf .io .parse_single_example (
4854 serialized_example , features = FEATURE_MAP
4955 )
5056 return parsed_example [self ._input_feature_key ]
51-
57+ serialized_examples = tf . cast ( serialized_examples , tf . string )
5258 return tf .map_fn (get_feature , serialized_examples )
5359
5460
55- class ExampleModel (keras .Model ):
56- """A Example Keras NLP model ."""
61+ class Reshaper (keras .layers . Layer ):
62+ """A Keras layer that reshapes the input ."""
5763
58- def __init__ (self , input_feature_key ):
59- super ().__init__ ()
60- self .parser = ExampleParser (input_feature_key )
61- self .text_vectorization = keras .layers .TextVectorization (
62- max_tokens = 32 ,
63- output_mode = 'int' ,
64- output_sequence_length = 32 ,
65- )
66- self .text_vectorization .adapt (
67- ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
68- )
69- self .dense1 = keras .layers .Dense (32 , activation = 'relu' )
70- self .dense2 = keras .layers .Dense (1 )
71-
72- def call (self , inputs , training = True , mask = None ):
73- parsed_example = self .parser (inputs )
74- text_vector = self .text_vectorization (parsed_example )
75- output1 = self .dense1 (tf .cast (text_vector , tf .float32 ))
76- output2 = self .dense2 (output1 )
77- return output2
64+ def call (self , inputs ):
65+ return tf .reshape (inputs , (1 , 32 ))
66+
67+
68+ def get_example_model (input_feature_key : str ):
69+ """Returns a Keras model for testing."""
70+ parser = ExampleParser (input_feature_key )
71+ text_vectorization = keras .layers .TextVectorization (
72+ max_tokens = 32 ,
73+ output_mode = 'int' ,
74+ output_sequence_length = 32 ,
75+ )
76+ text_vectorization .adapt (
77+ ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
78+ )
79+ dense1 = keras .layers .Dense (32 , activation = 'relu' )
80+ dense2 = keras .layers .Dense (1 )
81+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
82+ parsed_example = parser (inputs )
83+ text_vector = text_vectorization (parsed_example )
84+ text_vector = Reshaper ()(text_vector )
85+ text_vector = tf .cast (text_vector , tf .float32 )
86+ output1 = dense1 (text_vector )
87+ output2 = dense2 (output1 )
88+ return tf .keras .Model (inputs = inputs , outputs = output2 )
7889
7990
8091def evaluate_model (
@@ -83,6 +94,7 @@ def evaluate_model(
8394 tfma_eval_result_path ,
8495 eval_config ,
8596):
97+
8698 """Evaluate Model using Tensorflow Model Analysis.
8799
88100 Args:
0 commit comments