@@ -86,6 +86,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
8686 out_arr [i ][j ] = np .abs (np .vdot (final_wf , internal_wf ))** 2
8787
8888 self .assertAllClose (out , out_arr , atol = 1e-5 )
89+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
8990
9091 @parameterized .parameters ([
9192 {
@@ -138,6 +139,7 @@ def test_correctness_without_symbols(self, n_qubits, batch_size,
138139 out_arr [i ][j ] = np .abs (np .vdot (final_wf , internal_wf ))** 2
139140
140141 self .assertAllClose (out , out_arr , atol = 1e-5 )
142+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
141143
142144 def test_correctness_empty (self ):
143145 """Tests the fidelity with empty circuits."""
@@ -151,6 +153,7 @@ def test_correctness_empty(self):
151153 other_program )
152154 expected = np .array ([[1.0 ]], dtype = np .complex64 )
153155 self .assertAllClose (out , expected )
156+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
154157
155158 qubit = cirq .GridQubit (0 , 0 )
156159 non_empty_circuit = util .convert_to_tensor (
@@ -235,6 +238,7 @@ def test_tf_gradient_correctness_with_symbols(self, n_qubits, batch_size,
235238 out_arr [i ][k ] += grad_fid
236239
237240 self .assertAllClose (out , out_arr , atol = 1e-3 )
241+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
238242
239243 @parameterized .parameters ([
240244 {
@@ -272,6 +276,7 @@ def test_tf_gradient_correctness_without_symbols(self, n_qubits, batch_size,
272276 other_programs )
273277 out = tape .gradient (ip , symbol_values )
274278 self .assertAllClose (out , tf .zeros_like (symbol_values ), atol = 1e-3 )
279+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
275280
276281 def test_correctness_no_circuit (self ):
277282 """Test the inner product between no circuits."""
@@ -284,6 +289,7 @@ def test_correctness_no_circuit(self):
284289 out = fidelity_op .fidelity (empty_circuit , empty_symbols , empty_values ,
285290 other_program )
286291 self .assertShapeEqual (np .zeros ((0 , 0 )), out )
292+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
287293
288294 def test_tf_gradient_correctness_no_circuit (self ):
289295 """Test the inner product grad between no circuits."""
@@ -299,6 +305,7 @@ def test_tf_gradient_correctness_no_circuit(self):
299305 empty_values , other_program )
300306
301307 self .assertShapeEqual (np .zeros ((0 , 0 )), out )
308+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
302309
303310
304311if __name__ == "__main__" :
0 commit comments