Commit 8cc313f
authored
[TORCH] Add Kullback-Leibler divergence loss support (#4204)
This PR takes care of #4203.
- e2e support of **aten.kl_div** op supporting all reduction modes
(`mean, sum, batchmean, none`)
- `reduction: batchmean` requires special handling by calling op with
`sum` and then dividing it by input `batch_size`.
Some tests are failing and are marked either in expected failures or
crashing set.
- **config=linalg** | **RuntimeError**: attribute lookup is not defined
on builtin | **LINALG_XFAIL_SET**
- **config=torchdynamo** | **Error**: failed to legalize operation
'`torch.aten.xlogy.Tensor`' | **TORCHDYNAMO_CRASHING_SET**
- **config=onnx** | **RuntimeError**: aten::div() Expected a value of
type 'number' for argument 'other' but instead found type 'Tensor'
Position: 1
Value: tensor(1)
Declaration: aten::div.Scalar(Tensor self, Scalar other) -> Tensor
Cast error details: Cannot cast tensor(1) to number | **ONNX_XFAIL_SET**
- **config=onnx_tosa** | **Error**: failed to legalize operation
'`torch.aten.size.int`' that was explicitly marked illegal |
**ONNX_TOSA_XFAIL_SET**
---------
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>1 parent 867eb39 commit 8cc313f
File tree
9 files changed
+339
-0
lines changed- include/torch-mlir/Dialect/Torch/IR
- lib/Dialect/Torch/Transforms
- projects/pt1
- e2e_testing
- python
- torch_mlir_e2e_test/test_suite
- torch_mlir/jit_ir_importer/build_tools
9 files changed
+339
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9506 | 9506 | | |
9507 | 9507 | | |
9508 | 9508 | | |
| 9509 | + | |
| 9510 | + | |
| 9511 | + | |
| 9512 | + | |
| 9513 | + | |
| 9514 | + | |
| 9515 | + | |
| 9516 | + | |
| 9517 | + | |
| 9518 | + | |
| 9519 | + | |
| 9520 | + | |
| 9521 | + | |
| 9522 | + | |
| 9523 | + | |
| 9524 | + | |
| 9525 | + | |
| 9526 | + | |
| 9527 | + | |
| 9528 | + | |
| 9529 | + | |
| 9530 | + | |
| 9531 | + | |
| 9532 | + | |
| 9533 | + | |
| 9534 | + | |
9509 | 9535 | | |
9510 | 9536 | | |
9511 | 9537 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10692 | 10692 | | |
10693 | 10693 | | |
10694 | 10694 | | |
| 10695 | + | |
| 10696 | + | |
| 10697 | + | |
| 10698 | + | |
| 10699 | + | |
| 10700 | + | |
| 10701 | + | |
| 10702 | + | |
| 10703 | + | |
| 10704 | + | |
| 10705 | + | |
| 10706 | + | |
| 10707 | + | |
| 10708 | + | |
| 10709 | + | |
| 10710 | + | |
| 10711 | + | |
| 10712 | + | |
| 10713 | + | |
| 10714 | + | |
| 10715 | + | |
| 10716 | + | |
| 10717 | + | |
| 10718 | + | |
| 10719 | + | |
10695 | 10720 | | |
10696 | 10721 | | |
10697 | 10722 | | |
| |||
14575 | 14600 | | |
14576 | 14601 | | |
14577 | 14602 | | |
| 14603 | + | |
| 14604 | + | |
| 14605 | + | |
| 14606 | + | |
| 14607 | + | |
| 14608 | + | |
| 14609 | + | |
| 14610 | + | |
14578 | 14611 | | |
14579 | 14612 | | |
14580 | 14613 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10629 | 10629 | | |
10630 | 10630 | | |
10631 | 10631 | | |
| 10632 | + | |
| 10633 | + | |
| 10634 | + | |
| 10635 | + | |
| 10636 | + | |
| 10637 | + | |
| 10638 | + | |
| 10639 | + | |
| 10640 | + | |
| 10641 | + | |
| 10642 | + | |
| 10643 | + | |
| 10644 | + | |
| 10645 | + | |
| 10646 | + | |
| 10647 | + | |
| 10648 | + | |
| 10649 | + | |
| 10650 | + | |
| 10651 | + | |
| 10652 | + | |
| 10653 | + | |
| 10654 | + | |
| 10655 | + | |
| 10656 | + | |
| 10657 | + | |
| 10658 | + | |
| 10659 | + | |
| 10660 | + | |
| 10661 | + | |
| 10662 | + | |
| 10663 | + | |
| 10664 | + | |
| 10665 | + | |
| 10666 | + | |
| 10667 | + | |
| 10668 | + | |
| 10669 | + | |
| 10670 | + | |
| 10671 | + | |
| 10672 | + | |
| 10673 | + | |
| 10674 | + | |
| 10675 | + | |
| 10676 | + | |
| 10677 | + | |
| 10678 | + | |
| 10679 | + | |
| 10680 | + | |
| 10681 | + | |
| 10682 | + | |
| 10683 | + | |
| 10684 | + | |
| 10685 | + | |
| 10686 | + | |
| 10687 | + | |
| 10688 | + | |
| 10689 | + | |
| 10690 | + | |
| 10691 | + | |
| 10692 | + | |
| 10693 | + | |
| 10694 | + | |
| 10695 | + | |
| 10696 | + | |
| 10697 | + | |
| 10698 | + | |
| 10699 | + | |
| 10700 | + | |
| 10701 | + | |
| 10702 | + | |
| 10703 | + | |
| 10704 | + | |
| 10705 | + | |
| 10706 | + | |
| 10707 | + | |
| 10708 | + | |
| 10709 | + | |
10632 | 10710 | | |
10633 | 10711 | | |
10634 | 10712 | | |
| |||
12546 | 12624 | | |
12547 | 12625 | | |
12548 | 12626 | | |
| 12627 | + | |
12549 | 12628 | | |
12550 | 12629 | | |
12551 | 12630 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
587 | 587 | | |
588 | 588 | | |
589 | 589 | | |
| 590 | + | |
590 | 591 | | |
591 | 592 | | |
592 | 593 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| 42 | + | |
| 43 | + | |
42 | 44 | | |
43 | 45 | | |
44 | 46 | | |
| |||
386 | 388 | | |
387 | 389 | | |
388 | 390 | | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
389 | 397 | | |
390 | 398 | | |
391 | 399 | | |
| |||
3087 | 3095 | | |
3088 | 3096 | | |
3089 | 3097 | | |
| 3098 | + | |
3090 | 3099 | | |
3091 | 3100 | | |
3092 | 3101 | | |
| |||
3982 | 3991 | | |
3983 | 3992 | | |
3984 | 3993 | | |
| 3994 | + | |
| 3995 | + | |
| 3996 | + | |
| 3997 | + | |
| 3998 | + | |
| 3999 | + | |
3985 | 4000 | | |
3986 | 4001 | | |
3987 | 4002 | | |
| |||
Lines changed: 16 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2174 | 2174 | | |
2175 | 2175 | | |
2176 | 2176 | | |
| 2177 | + | |
| 2178 | + | |
| 2179 | + | |
| 2180 | + | |
| 2181 | + | |
| 2182 | + | |
| 2183 | + | |
| 2184 | + | |
2177 | 2185 | | |
2178 | 2186 | | |
2179 | 2187 | | |
| |||
4552 | 4560 | | |
4553 | 4561 | | |
4554 | 4562 | | |
| 4563 | + | |
| 4564 | + | |
| 4565 | + | |
| 4566 | + | |
| 4567 | + | |
| 4568 | + | |
| 4569 | + | |
| 4570 | + | |
4555 | 4571 | | |
4556 | 4572 | | |
4557 | 4573 | | |
| |||
Lines changed: 1 addition & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
764 | 764 | | |
765 | 765 | | |
766 | 766 | | |
| 767 | + | |
767 | 768 | | |
768 | 769 | | |
769 | 770 | | |
| |||
Lines changed: 1 addition & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
62 | 62 | | |
63 | 63 | | |
64 | 64 | | |
| 65 | + | |
0 commit comments