|
161 | 161 | "outputs": [], |
162 | 162 | "source": [ |
163 | 163 | "def build_model(w, h, class_num):\n", |
164 | | - " url = 'https://tfhub.dev/deepmind/ganeval-cifar10-convnet/1'\n", |
| 164 | + " url = \"https://www.kaggle.com/models/deepmind/ganeval-cifar10-convnet/frameworks/TensorFlow1/variations/ganeval-cifar10-convnet/versions/1\"\n", |
165 | 165 | " feature_extractor_layer = hub.KerasLayer(url, input_shape = (w, h, 3))\n", |
166 | 166 | " feature_extractor_layer.trainable = False\n", |
167 | 167 | "\n", |
|
421 | 421 | "from neural_compressor.data import DataLoader\n", |
422 | 422 | "from neural_compressor.quantization import fit\n", |
423 | 423 | "from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion\n", |
| 424 | + "from neural_compressor import Metric\n", |
| 425 | + "\n", |
424 | 426 | "\n", |
425 | 427 | "def auto_tune(input_graph_path, batch_size, int8_pb_file):\n", |
426 | 428 | " dataset = Dataset()\n", |
|
434 | 436 | " tolerable_loss=0.01 \n", |
435 | 437 | " )\n", |
436 | 438 | " )\n", |
| 439 | + "\n", |
| 440 | + " top1 = Metric(name=\"topk\", k=1)\n", |
| 441 | + " \n", |
437 | 442 | " q_model = fit(\n", |
438 | 443 | " model=input_graph_path,\n", |
439 | 444 | " conf=config,\n", |
440 | 445 | " calib_dataloader=dataloader,\n", |
441 | | - " eval_dataloader=dataloader\n", |
| 446 | + " eval_dataloader=dataloader,\n", |
| 447 | + " eval_metric=top1\n", |
442 | 448 | " )\n", |
443 | 449 | "\n", |
444 | 450 | " return q_model\n", |
|
588 | 594 | "}\n", |
589 | 595 | "```" |
590 | 596 | ] |
| 597 | + }, |
| 598 | + { |
| 599 | + "cell_type": "code", |
| 600 | + "execution_count": null, |
| 601 | + "id": "e0fbd1da-2731-4398-a327-9908c87c8c5f", |
| 602 | + "metadata": {}, |
| 603 | + "outputs": [], |
| 604 | + "source": [ |
| 605 | + "!which python " |
| 606 | + ] |
| 607 | + }, |
| 608 | + { |
| 609 | + "cell_type": "code", |
| 610 | + "execution_count": null, |
| 611 | + "id": "ae14f078-3414-45c5-bb98-b77eb33c0070", |
| 612 | + "metadata": {}, |
| 613 | + "outputs": [], |
| 614 | + "source": [] |
591 | 615 | } |
592 | 616 | ], |
593 | 617 | "metadata": { |
|
606 | 630 | "name": "python", |
607 | 631 | "nbconvert_exporter": "python", |
608 | 632 | "pygments_lexer": "ipython3", |
609 | | - "version": "3.9.15" |
| 633 | + "version": "3.9.18" |
610 | 634 | } |
611 | 635 | }, |
612 | 636 | "nbformat": 4, |
|
0 commit comments