From f00d1975bb032cc6f93341823b876cd54309c491 Mon Sep 17 00:00:00 2001 From: mhucka <1450019+mhucka@users.noreply.github.com> Date: Fri, 27 Mar 2026 05:27:25 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20[testing=20improvement]=20Add=20?= =?UTF-8?q?tests=20for=20random=5Fsymbol=5Fcircuit=5Fresolver=5Fbatch=20an?= =?UTF-8?q?d=20random=5Fcircuit=5Fresolver=5Fbatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive unit tests to `tensorflow_quantum/python/util_test.py` to verify the output shapes and types of `random_symbol_circuit_resolver_batch` and `random_circuit_resolver_batch`. These tests ensure that the batch generators return the correct number of `cirq.Circuit` and `cirq.ParamResolver` objects, and that symbols are correctly present in the resolvers when applicable. --- tensorflow_quantum/python/util_test.py | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..3edfa92b1 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -428,6 +428,37 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_random_circuit_resolver_batch(self): + """Confirm that random_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + batch_size = 5 + circuits, resolvers = util.random_circuit_resolver_batch( + qubits, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), 0) + + def test_random_symbol_circuit_resolver_batch(self): + """Confirm that random_symbol_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + symbols = [sympy.Symbol('a'), sympy.Symbol('b')] + batch_size = 5 + circuits, resolvers = util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), len(symbols)) + for symbol in symbols: + self.assertIn(symbol, resolver.param_dict) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work."""