|
51 | 51 | ) |
52 | 52 | from diffusers.optimization import get_scheduler |
53 | 53 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid |
| 54 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
54 | 55 | from diffusers.utils.import_utils import is_xformers_available |
55 | 56 | from diffusers.utils.torch_utils import is_compiled_module |
56 | 57 |
|
@@ -199,28 +200,32 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N |
199 | 200 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
200 | 201 | img_str += f"\n" |
201 | 202 |
|
202 | | - yaml = f""" |
203 | | ---- |
204 | | -license: openrail++ |
205 | | -base_model: {base_model} |
206 | | -tags: |
207 | | -- stable-diffusion-xl |
208 | | -- stable-diffusion-xl-diffusers |
209 | | -- text-to-image |
210 | | -- diffusers |
211 | | -- controlnet |
212 | | -inference: true |
213 | | ---- |
214 | | - """ |
215 | | - model_card = f""" |
| 203 | + model_description = f""" |
216 | 204 | # controlnet-{repo_id} |
217 | 205 |
|
218 | 206 | These are controlnet weights trained on {base_model} with new type of conditioning. |
219 | 207 | {img_str} |
220 | 208 | """ |
221 | 209 |
|
222 | | - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
223 | | - f.write(yaml + model_card) |
| 210 | + model_card = load_or_create_model_card( |
| 211 | + repo_id_or_path=repo_id, |
| 212 | + from_training=True, |
| 213 | + license="openrail++", |
| 214 | + base_model=base_model, |
| 215 | + model_description=model_description, |
| 216 | + inference=True, |
| 217 | + ) |
| 218 | + |
| 219 | + tags = [ |
| 220 | + "stable-diffusion-xl", |
| 221 | + "stable-diffusion-xl-diffusers", |
| 222 | + "text-to-image", |
| 223 | + "diffusers", |
| 224 | + "controlnet", |
| 225 | + ] |
| 226 | + model_card = populate_model_card(model_card, tags=tags) |
| 227 | + |
| 228 | + model_card.save(os.path.join(repo_folder, "README.md")) |
224 | 229 |
|
225 | 230 |
|
226 | 231 | def parse_args(input_args=None): |
|
0 commit comments