Commit 5614d80
authored
Add unified CEBRA encoder: pytorch implementation (#251)
* start tests
* remove print statements
* first passing test
* move functionality to base file in solver and separate in functions
* add test_select_model for multisession
* remove float16
* Improve modularity remove duplicate code and todos
* Add tests to solver
* Fix save/load
* Fix extra docs errors
* Add review updates
* apply ruff auto-fixes
* fix linting errors
* Run isort, ruff, yapf
* Fix gaussian mixture dataset import
* Fix all tests but xcebra tests
* Fix pytorch API usage example
* Make xCEBRA compatible with the batched inference & padding in solver
* Add some tests on transform() with xCEBRA
* Add some docstrings and typings and clean unnecessary changes
* Implement review comments
* Fix sklearn test
* Initial pass at integrating unifiedCEBRA
* Add name in NOTE
* Implement reviews on tests and typing
* Fix import errors
* Add select_model to aux solvers
* Fix tests
* Add mask tests
* Fix docs error
* Remove masking init()
* Remove shuffled neurons in unified dataset
* Remove extra datasets
* Add tests on the private functions in base solver
* Update tests and duplicate code based on review
* Fix quantized_embedding_norm undefined when `normalize=False` (#249)
* Fix tests
* Adapt unified code to get_model method
* Update mask.py
add headers to new files
* Update masking.py
- header
* Update test_data_masking.py
- header
* Implement review comments and fix typos
* Fix docs errors
* Remove np.int typing error
* Fix docstring warning
* Fix indentation docstrings
* Implement review comments
* Fix circular import and abstract method
* Add maskedmixin to __all__
* Implement extra review comments
* Change masking kwargs as tuple and not dict in sklearn impl
* Add integrations/decoders.py
* Fix typo
* minor simplification in solver
---------
Note, some comments in this PR overlap with
#168
and
#225
which were developed in parallel.1 parent 7ae5e1e commit 5614d80
File tree
22 files changed
+1951
-206
lines changed- cebra
- datasets
- data
- distributions
- integrations
- sklearn
- models
- solver
- docs/source/api/pytorch
- tests
22 files changed
+1951
-206
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
51 | 51 | | |
52 | 52 | | |
53 | 53 | | |
| 54 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
| 30 | + | |
30 | 31 | | |
31 | 32 | | |
32 | 33 | | |
| |||
36 | 37 | | |
37 | 38 | | |
38 | 39 | | |
39 | | - | |
| 40 | + | |
40 | 41 | | |
41 | 42 | | |
42 | 43 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
| 31 | + | |
31 | 32 | | |
| 33 | + | |
32 | 34 | | |
33 | 35 | | |
34 | 36 | | |
| |||
304 | 306 | | |
305 | 307 | | |
306 | 308 | | |
307 | | - | |
| 309 | + | |
308 | 310 | | |
309 | 311 | | |
310 | 312 | | |
| |||
435 | 437 | | |
436 | 438 | | |
437 | 439 | | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
| 504 | + | |
| 505 | + | |
| 506 | + | |
| 507 | + | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
| 514 | + | |
| 515 | + | |
| 516 | + | |
| 517 | + | |
| 518 | + | |
| 519 | + | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
0 commit comments