Skip to content

Commit 84adca9

Browse files
authored
Handle qudits in drop_terminal_measurements (#6879)
* Only apply X's, not I's, for invert_mask in drop_terminal_measurements This allows drop_terminal_measurements to work for qudits, which failed on the `I` application, given invert_mask values are undefined for them anyway. * Revive IdentityGate but with explicit dimension, to fix bug where the qubit was being removed completely if a terminal measurement was the only thing it contained.
1 parent 2949afd commit 84adca9

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

cirq-core/cirq/transformers/measurement_transformers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,20 @@ def drop_terminal_measurements(
292292
def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
293293
if isinstance(op.gate, ops.MeasurementGate):
294294
return [
295-
ops.X(q) if b else ops.I(q) for q, b in zip(op.qubits, op.gate.full_invert_mask())
295+
(
296+
(ops.X if b else ops.I)
297+
if q.dimension == 2
298+
else (
299+
ops.MatrixGate(
300+
# Per SimulationState.measure(), swap 0,1 but leave other dims alone
301+
np.eye(q.dimension)[[1, 0, *range(2, q.dimension)]],
302+
qid_shape=(q.dimension,),
303+
)
304+
if b
305+
else ops.IdentityGate(qid_shape=(q.dimension,))
306+
)
307+
).on(q)
308+
for q, b in zip(op.qubits, op.gate.full_invert_mask())
296309
]
297310
return op
298311

cirq-core/cirq/transformers/measurement_transformers_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,41 @@ def test_drop_terminal():
759759
)
760760

761761

762+
def test_drop_terminal_qudit():
763+
q0, q1 = cirq.LineQid.range(2, dimension=3)
764+
circuit = cirq.Circuit(
765+
cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0, q1, key='m', invert_mask=[0, 1])))
766+
)
767+
dropped = cirq.drop_terminal_measurements(circuit)
768+
expected_inversion_matrix = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
769+
cirq.testing.assert_same_circuits(
770+
dropped,
771+
cirq.Circuit(
772+
cirq.CircuitOperation(
773+
cirq.FrozenCircuit(
774+
cirq.IdentityGate(qid_shape=(3,)).on(q0),
775+
cirq.MatrixGate(expected_inversion_matrix, qid_shape=(3,)).on(q1),
776+
)
777+
)
778+
),
779+
)
780+
# Verify behavior equivalent to simulator (invert_mask swaps 0,1 but leaves 2 alone)
781+
dropped.append(cirq.measure(q0, q1, key='m'))
782+
sim = cirq.Simulator()
783+
c0 = sim.simulate(circuit, initial_state=[0, 0])
784+
d0 = sim.simulate(dropped, initial_state=[0, 0])
785+
assert np.all(c0.measurements['m'] == [0, 1])
786+
assert np.all(d0.measurements['m'] == [0, 1])
787+
c1 = sim.simulate(circuit, initial_state=[1, 1])
788+
d1 = sim.simulate(dropped, initial_state=[1, 1])
789+
assert np.all(c1.measurements['m'] == [1, 0])
790+
assert np.all(d1.measurements['m'] == [1, 0])
791+
c2 = sim.simulate(circuit, initial_state=[2, 2])
792+
d2 = sim.simulate(dropped, initial_state=[2, 2])
793+
assert np.all(c2.measurements['m'] == [2, 2])
794+
assert np.all(d2.measurements['m'] == [2, 2])
795+
796+
762797
def test_drop_terminal_nonterminal_error():
763798
q0, q1 = cirq.LineQubit.range(2)
764799
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)