|
48 | 48 | FlaxUNet2DConditionModel, |
49 | 49 | ) |
50 | 50 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid |
| 51 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
51 | 52 |
|
52 | 53 |
|
53 | 54 | # To prevent an error that occurs when there are abnormally large compressed data chunk in the png image |
@@ -145,28 +146,33 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N |
145 | 146 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
146 | 147 | img_str += f"\n" |
147 | 148 |
|
148 | | - yaml = f""" |
149 | | ---- |
150 | | -license: creativeml-openrail-m |
151 | | -base_model: {base_model} |
152 | | -tags: |
153 | | -- stable-diffusion |
154 | | -- stable-diffusion-diffusers |
155 | | -- text-to-image |
156 | | -- diffusers |
157 | | -- controlnet |
158 | | -- jax-diffusers-event |
159 | | -inference: true |
160 | | ---- |
161 | | - """ |
162 | | - model_card = f""" |
| 149 | + model_description = f""" |
163 | 150 | # controlnet- {repo_id} |
164 | 151 |
|
165 | 152 | These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n |
166 | 153 | {img_str} |
167 | 154 | """ |
168 | | - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
169 | | - f.write(yaml + model_card) |
| 155 | + |
| 156 | + model_card = load_or_create_model_card( |
| 157 | + repo_id_or_path=repo_id, |
| 158 | + from_training=True, |
| 159 | + license="creativeml-openrail-m", |
| 160 | + base_model=base_model, |
| 161 | + model_description=model_description, |
| 162 | + inference=True, |
| 163 | + ) |
| 164 | + |
| 165 | + tags = [ |
| 166 | + "stable-diffusion", |
| 167 | + "stable-diffusion-diffusers", |
| 168 | + "text-to-image", |
| 169 | + "diffusers", |
| 170 | + "controlnet", |
| 171 | + "jax-diffusers-event", |
| 172 | + ] |
| 173 | + model_card = populate_model_card(model_card, tags=tags) |
| 174 | + |
| 175 | + model_card.save(os.path.join(repo_folder, "README.md")) |
170 | 176 |
|
171 | 177 |
|
172 | 178 | def parse_args(): |
|
0 commit comments