From 41d9b9f05b7a6836be6880e88b6e28e72a0cd6bf Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 27 Mar 2026 05:00:40 +0000 Subject: [PATCH 1/3] Add tests for channels in gate_approx_eq `gate_approx_eq` in `tensorflow_quantum/python/util.py` had logic to handle quantum channels via `_channel_approx_eq`, but this path was not tested in `util_test.py`. This PR adds tests for the following scenarios: - All supported channels: `DepolarizingChannel`, `AsymmetricDepolarizingChannel`, `GeneralizedAmplitudeDampingChannel`, `AmplitudeDampingChannel`, `ResetChannel`, `PhaseDampingChannel`, `PhaseFlipChannel`, and `BitFlipChannel`. - Exact equality for each channel type. - Approximate equality within and outside the `atol` tolerance. - Type mismatches between different channels and between channels and non-channel gates. --- tensorflow_quantum/python/util_test.py | 132 ++++++++++++++++++++++++- 1 file changed, 130 insertions(+), 2 deletions(-) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..354157672 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -33,8 +33,8 @@ def _single_to_tensor(item): if not isinstance(item, (cirq.PauliSum, cirq.PauliString, cirq.Circuit)): - raise TypeError( - f"Item must be a Circuit or PauliSum. Got {type(item)}.") + raise TypeError("Item must be a Circuit or PauliSum. Got {}.".format( + type(item))) if isinstance(item, (cirq.PauliSum, cirq.PauliString)): return serializer.serialize_paulisum(item).SerializeToString( deterministic=True) @@ -347,6 +347,134 @@ def test_gate_approx_eq(self): util.gate_approx_eq( cirq.X, cirq.ops.ControlledGate(cirq.X, 2, [1, 0], [2, 2]))) + def test_gate_approx_eq_channels(self): + """Check valid TFQ channels for approximate equality.""" + atol = 1e-2 + + # DepolarizingChannel + self.assertTrue( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.DepolarizingChannel(0.1), + atol=atol)) + self.assertTrue( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.DepolarizingChannel(0.105), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.DepolarizingChannel(0.2), + atol=atol)) + + # AsymmetricDepolarizingChannel + self.assertTrue( + util.gate_approx_eq( + cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), + cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), + atol=atol)) + self.assertTrue( + util.gate_approx_eq( + cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), + cirq.AsymmetricDepolarizingChannel(0.105, 0.195, 0.305), + atol=atol)) + self.assertFalse( + util.gate_approx_eq( + cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), + cirq.AsymmetricDepolarizingChannel(0.2, 0.2, 0.3), + atol=atol)) + + # GeneralizedAmplitudeDampingChannel + self.assertTrue( + util.gate_approx_eq( + cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), + cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), + atol=atol)) + self.assertTrue( + util.gate_approx_eq( + cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), + cirq.GeneralizedAmplitudeDampingChannel(0.105, 0.205), + atol=atol)) + self.assertFalse( + util.gate_approx_eq( + cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), + cirq.GeneralizedAmplitudeDampingChannel(0.2, 0.2), + atol=atol)) + + # AmplitudeDampingChannel + self.assertTrue( + util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), + cirq.AmplitudeDampingChannel(0.1), + atol=atol)) + self.assertTrue( + util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), + cirq.AmplitudeDampingChannel(0.105), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), + cirq.AmplitudeDampingChannel(0.2), + atol=atol)) + + # ResetChannel + self.assertTrue( + util.gate_approx_eq(cirq.ResetChannel(), + cirq.ResetChannel(), + atol=atol)) + + # PhaseDampingChannel + self.assertTrue( + util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), + cirq.PhaseDampingChannel(0.1), + atol=atol)) + self.assertTrue( + util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), + cirq.PhaseDampingChannel(0.105), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), + cirq.PhaseDampingChannel(0.2), + atol=atol)) + + # PhaseFlipChannel + self.assertTrue( + util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), + cirq.PhaseFlipChannel(0.1), + atol=atol)) + self.assertTrue( + util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), + cirq.PhaseFlipChannel(0.105), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), + cirq.PhaseFlipChannel(0.2), + atol=atol)) + + # BitFlipChannel + self.assertTrue( + util.gate_approx_eq(cirq.BitFlipChannel(0.1), + cirq.BitFlipChannel(0.1), + atol=atol)) + self.assertTrue( + util.gate_approx_eq(cirq.BitFlipChannel(0.1), + cirq.BitFlipChannel(0.105), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.BitFlipChannel(0.1), + cirq.BitFlipChannel(0.2), + atol=atol)) + + # Mismatched types + self.assertFalse( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.BitFlipChannel(0.1), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.X, + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.X, + cirq.DepolarizingChannel(0.1), + atol=atol)) + def test_gate_approx_eq_error(self): """Confirms that bad inputs cause an error to be raised.""" # junk From ae899a6abcfe60c46adb485f291a9c8b1f4c74f1 Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 27 Mar 2026 05:20:00 +0000 Subject: [PATCH 2/3] Improve code Update error message, and streamline test code --- tensorflow_quantum/python/util_test.py | 133 ++++++------------------- 1 file changed, 29 insertions(+), 104 deletions(-) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 354157672..fcad5a49c 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -33,8 +33,8 @@ def _single_to_tensor(item): if not isinstance(item, (cirq.PauliSum, cirq.PauliString, cirq.Circuit)): - raise TypeError("Item must be a Circuit or PauliSum. Got {}.".format( - type(item))) + raise TypeError(f"Item must be a Circuit, PauliString, or PauliSum." + " Got {type(item)}.") if isinstance(item, (cirq.PauliSum, cirq.PauliString)): return serializer.serialize_paulisum(item).SerializeToString( deterministic=True) @@ -351,67 +351,34 @@ def test_gate_approx_eq_channels(self): """Check valid TFQ channels for approximate equality.""" atol = 1e-2 - # DepolarizingChannel - self.assertTrue( - util.gate_approx_eq(cirq.DepolarizingChannel(0.1), - cirq.DepolarizingChannel(0.1), - atol=atol)) - self.assertTrue( - util.gate_approx_eq(cirq.DepolarizingChannel(0.1), - cirq.DepolarizingChannel(0.105), - atol=atol)) - self.assertFalse( - util.gate_approx_eq(cirq.DepolarizingChannel(0.1), - cirq.DepolarizingChannel(0.2), - atol=atol)) - - # AsymmetricDepolarizingChannel - self.assertTrue( - util.gate_approx_eq( - cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), - cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), - atol=atol)) - self.assertTrue( - util.gate_approx_eq( - cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), - cirq.AsymmetricDepolarizingChannel(0.105, 0.195, 0.305), - atol=atol)) - self.assertFalse( - util.gate_approx_eq( - cirq.AsymmetricDepolarizingChannel(0.1, 0.2, 0.3), - cirq.AsymmetricDepolarizingChannel(0.2, 0.2, 0.3), - atol=atol)) - - # GeneralizedAmplitudeDampingChannel - self.assertTrue( - util.gate_approx_eq( - cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), - cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), - atol=atol)) - self.assertTrue( - util.gate_approx_eq( - cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), - cirq.GeneralizedAmplitudeDampingChannel(0.105, 0.205), - atol=atol)) - self.assertFalse( - util.gate_approx_eq( - cirq.GeneralizedAmplitudeDampingChannel(0.1, 0.2), - cirq.GeneralizedAmplitudeDampingChannel(0.2, 0.2), - atol=atol)) + test_cases = [ + (cirq.DepolarizingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.AsymmetricDepolarizingChannel, (0.1, 0.2, 0.3), + (0.105, 0.195, 0.305), (0.2, 0.2, 0.3)), + (cirq.GeneralizedAmplitudeDampingChannel, (0.1, 0.2), + (0.105, 0.205), (0.2, 0.2)), + (cirq.AmplitudeDampingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.PhaseDampingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.PhaseFlipChannel, (0.1,), (0.105,), (0.2,)), + (cirq.BitFlipChannel, (0.1,), (0.105,), (0.2,)), + ] - # AmplitudeDampingChannel - self.assertTrue( - util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), - cirq.AmplitudeDampingChannel(0.1), - atol=atol)) - self.assertTrue( - util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), - cirq.AmplitudeDampingChannel(0.105), - atol=atol)) - self.assertFalse( - util.gate_approx_eq(cirq.AmplitudeDampingChannel(0.1), - cirq.AmplitudeDampingChannel(0.2), - atol=atol)) + for channel, exact_params, approx_params, unequal_params in test_cases: + with self.subTest(channel=channel.__name__): + gate1 = channel(*exact_params) + gate2_exact = channel(*exact_params) + gate2_approx = channel(*approx_params) + gate2_not_equal = channel(*unequal_params) + + # Exact equality + self.assertTrue( + util.gate_approx_eq(gate1, gate2_exact, atol=atol)) + # Approximate equality + self.assertTrue( + util.gate_approx_eq(gate1, gate2_approx, atol=atol)) + # Not equal + self.assertFalse( + util.gate_approx_eq(gate1, gate2_not_equal, atol=atol)) # ResetChannel self.assertTrue( @@ -419,48 +386,6 @@ def test_gate_approx_eq_channels(self): cirq.ResetChannel(), atol=atol)) - # PhaseDampingChannel - self.assertTrue( - util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), - cirq.PhaseDampingChannel(0.1), - atol=atol)) - self.assertTrue( - util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), - cirq.PhaseDampingChannel(0.105), - atol=atol)) - self.assertFalse( - util.gate_approx_eq(cirq.PhaseDampingChannel(0.1), - cirq.PhaseDampingChannel(0.2), - atol=atol)) - - # PhaseFlipChannel - self.assertTrue( - util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), - cirq.PhaseFlipChannel(0.1), - atol=atol)) - self.assertTrue( - util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), - cirq.PhaseFlipChannel(0.105), - atol=atol)) - self.assertFalse( - util.gate_approx_eq(cirq.PhaseFlipChannel(0.1), - cirq.PhaseFlipChannel(0.2), - atol=atol)) - - # BitFlipChannel - self.assertTrue( - util.gate_approx_eq(cirq.BitFlipChannel(0.1), - cirq.BitFlipChannel(0.1), - atol=atol)) - self.assertTrue( - util.gate_approx_eq(cirq.BitFlipChannel(0.1), - cirq.BitFlipChannel(0.105), - atol=atol)) - self.assertFalse( - util.gate_approx_eq(cirq.BitFlipChannel(0.1), - cirq.BitFlipChannel(0.2), - atol=atol)) - # Mismatched types self.assertFalse( util.gate_approx_eq(cirq.DepolarizingChannel(0.1), From e147592aa11f7144cd66d445244dd7fa3b90978d Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 27 Mar 2026 05:21:32 +0000 Subject: [PATCH 3/3] Fix incorrect f-string placement --- tensorflow_quantum/python/util_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index fcad5a49c..362ca71fc 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -33,8 +33,8 @@ def _single_to_tensor(item): if not isinstance(item, (cirq.PauliSum, cirq.PauliString, cirq.Circuit)): - raise TypeError(f"Item must be a Circuit, PauliString, or PauliSum." - " Got {type(item)}.") + raise TypeError("Item must be a Circuit, PauliString, or PauliSum." + f" Got {type(item)}.") if isinstance(item, (cirq.PauliSum, cirq.PauliString)): return serializer.serialize_paulisum(item).SerializeToString( deterministic=True)