refactor: add backward compatibility with legacy hyperprior state dict#352
Closed
studyingeugene wants to merge 1 commit into InterDigitalInc:master from
Conversation
Contributor
Thanks, I had forgotten about this. However, I hesitate to add complexity to `load_state_dict()`. This only affects users loading previously trained ELIC checkpoints after upgrading CompressAI. I think a better solution would be for downstream users to either pin to a particular CompressAI version or do a one-time upgrade of those checkpoints with a small conversion script.

Click for example python script:

```python
from __future__ import annotations

import argparse
from collections.abc import Mapping
from pathlib import Path

import torch


def remap_legacy_hyperprior_keys(
    state_dict: Mapping[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Rename legacy nested hyperprior keys to the current layout."""
    remapped: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        new_key = key
        if ".hyper.h_a." in key:
            new_key = key.replace(".hyper.h_a.", ".h_a.")
        elif ".hyper.h_s." in key:
            new_key = key.replace(".hyper.h_s.", ".h_s.")
        elif ".hyper.entropy_bottleneck." in key:
            new_key = key.replace(
                ".hyper.entropy_bottleneck.",
                ".z.entropy_bottleneck.",
            )
        remapped[new_key] = value
    return remapped


def convert_checkpoint(checkpoint: object) -> object:
    if isinstance(checkpoint, dict):
        converted = dict(checkpoint)
        if "state_dict" in converted and isinstance(converted["state_dict"], Mapping):
            converted["state_dict"] = remap_legacy_hyperprior_keys(converted["state_dict"])
            return converted
        if all(isinstance(k, str) for k in converted):
            return remap_legacy_hyperprior_keys(converted)
    raise TypeError("Unsupported checkpoint format")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=Path)
    parser.add_argument("output", type=Path)
    args = parser.parse_args()
    checkpoint = torch.load(args.input, map_location="cpu")
    converted = convert_checkpoint(checkpoint)
    torch.save(converted, args.output)


if __name__ == "__main__":
    main()
```

(Disclaimer: generated code.)

We currently do not provide pretrained checkpoints for ELIC or proper benchmarks, so this is still somewhat experimental.
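As a quick sanity check of the remap rules from the script above, one can run them on a toy state dict. The key names below are illustrative and the plain integer values stand in for tensors; `RULES` and `remap` are a minimal re-implementation, not part of CompressAI:

```python
# Minimal re-implementation of the script's remap rules, applied to a
# synthetic legacy state dict. Plain values stand in for tensors.
RULES = [
    (".hyper.h_a.", ".h_a."),
    (".hyper.h_s.", ".h_s."),
    (".hyper.entropy_bottleneck.", ".z.entropy_bottleneck."),
]

def remap(state_dict: dict) -> dict:
    out = {}
    for key, value in state_dict.items():
        for old, new in RULES:
            if old in key:
                key = key.replace(old, new)
                break  # each key matches at most one rule
        out[key] = value
    return out

legacy = {
    "latent_codec.hyper.h_a.0.weight": 0,
    "latent_codec.hyper.entropy_bottleneck.quantiles": 0,
    "g_a.0.weight": 0,  # unrelated keys pass through untouched
}
print(sorted(remap(legacy)))
# ['g_a.0.weight', 'latent_codec.h_a.0.weight', 'latent_codec.z.entropy_bottleneck.quantiles']
```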
Contributor
Author

Thank you for your review.
Summary
This PR restores backward compatibility when loading legacy checkpoints for models using `HyperpriorLatentCodec`.

Problem
Older checkpoints use the nested hyperprior layout:

- `latent_codec.hyper.h_a.*`
- `latent_codec.hyper.h_s.*`
- `latent_codec.hyper.entropy_bottleneck.*`

Current checkpoints use the refactored layout:

- `latent_codec.h_a.*`
- `latent_codec.h_s.*`
- `latent_codec.z.entropy_bottleneck.*`

`CompressionModel.load_state_dict()` already handles some legacy `EntropyBottleneck` key conversions, but it did not handle this module-path migration. As a result, loading old checkpoints in the current code raises strict key mismatches.

What Changed
This PR keeps the fix minimal and limited to checkpoint loading:

- `compressai/models/utils.py`
- `CompressionModel.load_state_dict()` in `compressai/models/base.py`
- `EntropyBottleneck` remap logic unchanged

The remap is:

- `...hyper.h_a.*` -> `...h_a.*`
- `...hyper.h_s.*` -> `...h_s.*`
- `...hyper.entropy_bottleneck.*` -> `...z.entropy_bottleneck.*`

Affected Models
Confirmed in this repository:

- `cheng2020-anchor-checkerboard`
- `elic2022-official`
- `elic2022-chandelier`

Validation
Using checkpoints exported from commit `6cc371fde518e61bb9953a00a1dbec1b858e899f`, I verified that the patched current code loads them successfully for:

- `cheng2020-anchor-checkerboard`
- `elic2022-official`
- `elic2022-chandelier`

and that the resulting current-format `state_dict` matches the expected remapped tensors exactly.

Verification result:

```json
{
  "cheng2020-anchor-checkerboard": { "loaded": true, "num_keys": 160, "mismatches": 0 },
  "elic-official": { "loaded": true, "num_keys": 412, "mismatches": 0 },
  "elic-chandelier": { "loaded": true, "num_keys": 412, "mismatches": 0 }
}
```
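A per-model check like the one reported above can be sketched as follows. The `verify` helper is hypothetical (not part of this PR); plain values stand in for tensors, so with real state dicts one would compare values via `torch.equal` instead of `!=`:

```python
# Sketch of the per-model verification: compare a loaded state dict against
# the expected remapped one, key by key. Plain values stand in for tensors.
def verify(expected: dict, actual: dict) -> dict:
    shared = expected.keys() & actual.keys()
    missing = expected.keys() ^ actual.keys()  # keys present on one side only
    value_mismatches = sum(1 for k in shared if expected[k] != actual[k])
    return {
        "loaded": True,
        "num_keys": len(actual),
        "mismatches": len(missing) + value_mismatches,
    }

expected = {
    "latent_codec.h_a.0.weight": 1,
    "latent_codec.z.entropy_bottleneck.quantiles": 2,
}
print(verify(expected, dict(expected)))
# {'loaded': True, 'num_keys': 2, 'mismatches': 0}
```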
Scope

This change only restores backward compatibility for loading legacy checkpoints into the current code.
Thanks for your time reviewing this PR.