diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..362ca71fc 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -33,8 +33,8 @@ def _single_to_tensor(item): if not isinstance(item, (cirq.PauliSum, cirq.PauliString, cirq.Circuit)): - raise TypeError( - f"Item must be a Circuit or PauliSum. Got {type(item)}.") + raise TypeError("Item must be a Circuit, PauliString, or PauliSum." + f" Got {type(item)}.") if isinstance(item, (cirq.PauliSum, cirq.PauliString)): return serializer.serialize_paulisum(item).SerializeToString( deterministic=True) @@ -347,6 +347,59 @@ def test_gate_approx_eq(self): util.gate_approx_eq( cirq.X, cirq.ops.ControlledGate(cirq.X, 2, [1, 0], [2, 2]))) + def test_gate_approx_eq_channels(self): + """Check valid TFQ channels for approximate equality.""" + atol = 1e-2 + + test_cases = [ + (cirq.DepolarizingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.AsymmetricDepolarizingChannel, (0.1, 0.2, 0.3), + (0.105, 0.195, 0.305), (0.2, 0.2, 0.3)), + (cirq.GeneralizedAmplitudeDampingChannel, (0.1, 0.2), + (0.105, 0.205), (0.2, 0.2)), + (cirq.AmplitudeDampingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.PhaseDampingChannel, (0.1,), (0.105,), (0.2,)), + (cirq.PhaseFlipChannel, (0.1,), (0.105,), (0.2,)), + (cirq.BitFlipChannel, (0.1,), (0.105,), (0.2,)), + ] + + for channel, exact_params, approx_params, unequal_params in test_cases: + with self.subTest(channel=channel.__name__): + gate1 = channel(*exact_params) + gate2_exact = channel(*exact_params) + gate2_approx = channel(*approx_params) + gate2_not_equal = channel(*unequal_params) + + # Exact equality + self.assertTrue( + util.gate_approx_eq(gate1, gate2_exact, atol=atol)) + # Approximate equality + self.assertTrue( + util.gate_approx_eq(gate1, gate2_approx, atol=atol)) + # Not equal + self.assertFalse( + util.gate_approx_eq(gate1, gate2_not_equal, atol=atol)) + + # ResetChannel + self.assertTrue( + util.gate_approx_eq(cirq.ResetChannel(), + cirq.ResetChannel(), + atol=atol)) + + # Mismatched types + self.assertFalse( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.BitFlipChannel(0.1), + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.DepolarizingChannel(0.1), + cirq.X, + atol=atol)) + self.assertFalse( + util.gate_approx_eq(cirq.X, + cirq.DepolarizingChannel(0.1), + atol=atol)) + def test_gate_approx_eq_error(self): """Confirms that bad inputs cause an error to be raised.""" # junk