
refactor: add backward compatibility with legacy hyperprior state dict#352

Closed
studyingeugene wants to merge 1 commit into InterDigitalInc:master from studyingeugene:master

Conversation

@studyingeugene
Contributor

@studyingeugene studyingeugene commented Apr 3, 2026

Summary

This PR restores backward compatibility when loading legacy checkpoints for models using HyperpriorLatentCodec.

Problem

Older checkpoints use the nested hyperprior layout:

  • latent_codec.hyper.h_a.*
  • latent_codec.hyper.h_s.*
  • latent_codec.hyper.entropy_bottleneck.*

Current checkpoints use the refactored layout:

  • latent_codec.h_a.*
  • latent_codec.h_s.*
  • latent_codec.z.entropy_bottleneck.*

CompressionModel.load_state_dict() already handles some legacy EntropyBottleneck key conversions, but it does not handle this module-path migration. As a result, loading old checkpoints with the current code fails with strict key-mismatch errors.

What Changed

This PR keeps the fix minimal and limited to checkpoint loading:

  • add a legacy hyperprior key remap helper in compressai/models/utils.py
  • apply that remap from CompressionModel.load_state_dict() in compressai/models/base.py
  • keep the existing buffer update and legacy EntropyBottleneck remap logic unchanged

The remap is:

  • ...hyper.h_a.* -> ...h_a.*
  • ...hyper.h_s.* -> ...h_s.*
  • ...hyper.entropy_bottleneck.* -> ...z.entropy_bottleneck.*

Affected Models

Confirmed in this repository:

  • cheng2020-anchor-checkerboard
  • elic2022-official
  • elic2022-chandelier

Validation

Using checkpoints exported from commit 6cc371fde518e61bb9953a00a1dbec1b858e899f, I verified that the patched current code loads them successfully for:

  • cheng2020-anchor-checkerboard
  • elic2022-official
  • elic2022-chandelier

and that the resulting current-format state_dict matches the expected remapped tensors exactly.

Verification result:

{
  "cheng2020-anchor-checkerboard": {
    "loaded": true,
    "num_keys": 160,
    "mismatches": 0
  },
  "elic-official": {
    "loaded": true,
    "num_keys": 412,
    "mismatches": 0
  },
  "elic-chandelier": {
    "loaded": true,
    "num_keys": 412,
    "mismatches": 0
  }
}
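A mismatch count like the one reported above can be computed with a simple key-by-key comparison. The sketch below is dependency-free for illustration (in practice the `equal` argument would be `torch.equal` for tensor comparison); it is not the actual verification script used for this PR:

```python
def count_state_dict_mismatches(expected, actual, equal=lambda a, b: a == b):
    """Count keys missing from `actual` or whose values differ from `expected`.

    In a real verification run, `equal` would be `torch.equal` so that
    tensors are compared element-wise; the plain-equality default keeps
    this sketch runnable without torch.
    """
    return sum(
        1
        for key, value in expected.items()
        if key not in actual or not equal(value, actual[key])
    )
```

A result of 0 means every expected remapped tensor was found, matching the `"mismatches": 0` entries in the report above.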

Scope

This change only restores backward compatibility for loading legacy checkpoints into the current code.

Thanks for your time reviewing this PR.

@YodaEmbedding
Contributor

YodaEmbedding commented Apr 3, 2026

Thanks, I had forgotten about this.

However, I hesitate to add complexity to load_state_dict(). This only affects users loading previously trained ELIC checkpoints after upgrading CompressAI. I think a better solution would be for downstream users to either pin to a particular CompressAI version or do a one-time upgrade of those checkpoints with a small conversion script.

Example Python script:
from __future__ import annotations

import argparse
from collections.abc import Mapping
from pathlib import Path

import torch


def remap_legacy_hyperprior_keys(
    state_dict: Mapping[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    remapped: dict[str, torch.Tensor] = {}

    for key, value in state_dict.items():
        new_key = key

        if ".hyper.h_a." in key:
            new_key = key.replace(".hyper.h_a.", ".h_a.")
        elif ".hyper.h_s." in key:
            new_key = key.replace(".hyper.h_s.", ".h_s.")
        elif ".hyper.entropy_bottleneck." in key:
            new_key = key.replace(
                ".hyper.entropy_bottleneck.",
                ".z.entropy_bottleneck.",
            )

        remapped[new_key] = value

    return remapped


def convert_checkpoint(checkpoint: object) -> object:
    if isinstance(checkpoint, dict):
        converted = dict(checkpoint)

        if "state_dict" in converted and isinstance(converted["state_dict"], Mapping):
            converted["state_dict"] = remap_legacy_hyperprior_keys(converted["state_dict"])
            return converted

        if all(isinstance(k, str) for k in converted):
            return remap_legacy_hyperprior_keys(converted)

    raise TypeError("Unsupported checkpoint format")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=Path)
    parser.add_argument("output", type=Path)
    args = parser.parse_args()

    checkpoint = torch.load(args.input, map_location="cpu")
    converted = convert_checkpoint(checkpoint)
    torch.save(converted, args.output)


if __name__ == "__main__":
    main()

(Disclaimer: generated code.)

We currently do not provide pretrained checkpoints for ELIC or proper benchmarks, so this is still somewhat experimental.

@studyingeugene
Contributor Author

Thank you for your review.
I agree that it is preferable for downstream users to adapt their code to CompressAI.
I'll close this PR.
