diff --git a/glue/sample/requirements.txt b/glue/sample/requirements.txt index 4b0c6c5a3..0703ff082 100644 --- a/glue/sample/requirements.txt +++ b/glue/sample/requirements.txt @@ -2,3 +2,4 @@ matplotlib numpy stim scipy +packaging diff --git a/glue/sample/src/sinter/_decoding/_decoding.py b/glue/sample/src/sinter/_decoding/_decoding.py index 1e54f87ef..af5d9427b 100644 --- a/glue/sample/src/sinter/_decoding/_decoding.py +++ b/glue/sample/src/sinter/_decoding/_decoding.py @@ -180,6 +180,8 @@ def sample_decode(*, decoder: The name of the decoder to use. Allowed values are: "pymatching": Use pymatching min-weight-perfect-match decoder. + "pymatching-correlated": + Use pymatching min-weight-perfect-match decoder with correlated decoding enabled. "internal": Use internal decoder with uncorrelated decoding. "internal_correlated": diff --git a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py index 92d8d49dd..93ffa584f 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py @@ -12,6 +12,7 @@ BUILT_IN_DECODERS: Dict[str, Decoder] = { 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), + 'pymatching-correlated': PyMatchingDecoder(use_correlated_decoding=True), 'fusion_blossom': FusionBlossomDecoder(), # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) 'hypergraph_union_find': HyperUFDecoder(), diff --git a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py index b57bb32bc..7e2da1dcd 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py +++ b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py @@ -1,26 +1,45 @@ from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder +def check_pymatching_version_for_correlated_decoding(pymatching): + import packaging.version + if packaging.version.parse(pymatching.__version__) < packaging.version.parse("2.3.1"): + raise ValueError( + "Pymatching version must be at least 2.3.1 for correlated decoding.\n" + f"Installed version: {pymatching.__version__}\n" + "To fix this, install a newer version of pymatching into your environment.\n" + "For example, if you are using pip, run `pip install pymatching --upgrade`.\n" + ) + + class PyMatchingCompiledDecoder(CompiledDecoder): - def __init__(self, matcher: 'pymatching.Matching'): + def __init__(self, matcher: 'pymatching.Matching', use_correlated_decoding: bool): self.matcher = matcher + self.use_correlated_decoding = use_correlated_decoding def decode_shots_bit_packed( self, *, bit_packed_detection_event_data: 'np.ndarray', ) -> 'np.ndarray': + kwargs = {} + if self.use_correlated_decoding: + kwargs['enable_correlations'] = True return self.matcher.decode_batch( shots=bit_packed_detection_event_data, bit_packed_shots=True, bit_packed_predictions=True, return_weights=False, + **kwargs, ) class PyMatchingDecoder(Decoder): """Use pymatching to predict observables from detection events.""" + def __init__(self, use_correlated_decoding: bool = False): + self.use_correlated_decoding = use_correlated_decoding + def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder: try: import pymatching @@ -31,7 +50,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled "For example, if you are using pip, run `pip install pymatching`.\n" ) from ex - return PyMatchingCompiledDecoder(pymatching.Matching.from_detector_error_model(dem)) + kwargs = {} + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + kwargs['enable_correlations'] = True + return PyMatchingCompiledDecoder( + pymatching.Matching.from_detector_error_model(dem, **kwargs), + use_correlated_decoding=self.use_correlated_decoding, + ) def decode_via_files(self, *, @@ -59,7 +85,9 @@ def decode_via_files(self, if not hasattr(pymatching, 'cli'): raise ValueError(""" + The installed version of pymatching has no `pymatching.cli` method. + sinter requires pymatching 2.1.0 or later. If you're using pip to install packages, this can be fixed by running @@ -69,13 +97,18 @@ def decode_via_files(self, """) - result = pymatching.cli(command_line_args=[ + args = [ "predict", "--dem", str(dem_path), "--in", str(dets_b8_in_path), "--in_format", "b8", "--out", str(obs_predictions_b8_out_path), "--out_format", "b8", - ]) + ] + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + args.append("--enable_correlations") + + result = pymatching.cli(command_line_args=args) if result: - raise ValueError("pymatching.cli returned a non-zero exit code") + raise ValueError("pymatching.cli returned a non-zero exit code") \ No newline at end of file diff --git a/glue/sample/src/sinter/_decoding/_decoding_test.py b/glue/sample/src/sinter/_decoding/_decoding_test.py index cd4e28d0d..26678a182 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding/_decoding_test.py @@ -233,6 +233,8 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo @pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) def test_post_selection(decoder: str, force_streaming: Optional[bool]): + if decoder == 'pymatching-correlated': + pytest.skip("Correlated matching does not support error probabilities > 0.5 in from_detector_error_model") circuit = stim.Circuit(""" X_ERROR(0.6) 0 M 0