diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..3edfa92b1 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -428,6 +428,37 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_random_circuit_resolver_batch(self): + """Confirm that random_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + batch_size = 5 + circuits, resolvers = util.random_circuit_resolver_batch( + qubits, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), 0) + + def test_random_symbol_circuit_resolver_batch(self): + """Confirm that random_symbol_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + symbols = [sympy.Symbol('a'), sympy.Symbol('b')] + batch_size = 5 + circuits, resolvers = util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), len(symbols)) + for symbol in symbols: + self.assertIn(symbol, resolver.param_dict) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work."""