Skip to content

Conversation

@nathanneike
Copy link

@nathanneike nathanneike commented Oct 28, 2025

Types of changes

  • New feature (non-breaking change which adds functionality)
  • Tests (new tests or changes to existing tests)

Motivation and context / Related issue

This PR implements a sparse EMD solver for memory-efficient optimal transport when the cost matrix has many infinite or forbidden edges (e.g., k-NN graphs, sparse networks).

Problem: The current dense EMD solver requires O(n²) memory for the full cost matrix, which becomes prohibitive for large-scale
problems even when most edges are forbidden.

Solution: This PR adds a sparse bipartite graph solver that only stores edges with finite costs, reducing memory usage from O(n²) to O(E) where E is the number of edges.

Use cases:

  • k-NN graph optimal transport
  • Large-scale sparse matching problems
  • Network flow with forbidden edges

How has this been tested

Unit Tests

Added two comprehensive tests in test/test_ot.py:

  • test_emd_sparse_vs_dense() - Verifies sparse and dense solvers produce identical transport matrices
  • test_emd2_sparse_vs_dense() - Verifies sparse and dense solvers produce identical costs

Both tests use the augmented k-NN approach:

  1. Create initial k-NN sparse graph
  2. Solve with dense solver to identify needed edges
  3. Augment graph with those edges
  4. Compare both solvers on identical graph structure

Test results: All 50 tests in test/test_ot.py pass

Verification

  • Costs match between solvers
  • Marginal constraints satisfied for both solvers
  • No regression in existing tests

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts). TODO: Add documentation
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file. TODO: Will add once ready for merge

TODO before [MRG]:

  • Add example script in examples/ folder demonstrating sparse solver usage
  • Add documentation explaining when to use sparse vs dense
  • Performance benchmarks comparing memory usage and runtime
  • Update RELEASES.md

Feedback requested:

  • Is the API design appropriate? (using sparse=True parameter)
  • Should we add more comprehensive tests?
  • Any concerns about the C++ implementation approach?

  - Implement sparse bipartite graph EMD solver in C++
  - Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py)
  - Add unit tests to verify sparse and dense solvers produce identical results
  - Tests use augmented k-NN approach to ensure fair comparison
  - Update setup.py to include sparse solver compilation

  Both test_emd_sparse_vs_dense() and test_emd2_sparse_vs_dense() verify:
    * Identical costs between sparse and dense solvers
    * Marginal constraint satisfaction for both solvers
  This PR implements a sparse bipartite graph EMD solver for memory-efficient
  optimal transport when the cost matrix has many infinite or forbidden edges.

  Changes:
  - Implement sparse bipartite graph EMD solver in C++
  - Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py)
  - Add unit tests to verify sparse and dense solvers produce identical results
  - Tests use augmented k-NN approach to ensure fair comparison

  Tests verify correctness:
    * test_emd_sparse_vs_dense() - verifies identical costs and marginal constraints
    * test_emd2_sparse_vs_dense() - verifies cost-only version

  Status: WIP - seeking feedback on implementation approach
  TODO: Add example script and documentation
@rflamary rflamary changed the title Sparse emd implementation [WIP] Sparse emd implementation Oct 28, 2025
…trix parameter from emd and fix linting issues
@codecov
Copy link

codecov bot commented Nov 2, 2025

Codecov Report

❌ Patch coverage is 94.15205% with 20 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.12%. Comparing base (d3867c6) to head (b184cd4).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #778      +/-   ##
==========================================
- Coverage   97.15%   97.12%   -0.04%     
==========================================
  Files         107      107              
  Lines       21906    22195     +289     
==========================================
+ Hits        21283    21556     +273     
- Misses        623      639      +16     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

nathanneike and others added 4 commits November 3, 2025 13:47
- Remove tuple format support for sparse matrices (use scipy.sparse only)
- Change index types from int64_t to uint64_t throughout (indices are never negative)
- Refactor emd() and emd2() with clear sparse/dense code path separation
- Add sparse_bipartitegraph.h to MANIFEST.in to fix build
- Add test_emd_sparse_backends() to verify backend compatibility
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much @nathanneike for this PR.

I have many small comments but it is already looking very nice.

The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan.
.. image:: /_static/images/comparison.png
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us not add static image file. could you do a quick bench below and compare computational time?


# %%

X = np.array([[0, 0], [1, 0]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very simple example, maybe we can design a sparse example with more points?


# Solve sparse OT (intra-cluster only)
G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True)
cost_sparse = log_sparse["cost"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for log np.sum(G_sparse*M_sparse_large) shoudl work


# Dense OT
plt.subplot(1, 2, 1)
for i in range(nA):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have a function for that, ot.plot.plot2D_samples_mat
https://pythonot.github.io/gen_modules/ot.plot.html#id1

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the function shoud be update to handle sparse OT matrices maybe


return None, log_dict
else:
raise ValueError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return teh OT plan in sparse fomart (same type and device as M)

np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7)

# Reconstruct sparse matrix from flow for marginal checks
if G_sparse is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

erturn spasre matrix insteda of doing that

cols = []
data = []

for i in range(n_source):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

)

C_augmented_dense = np.full((n_source, n_target), large_cost)
C_augmented_array = C_augmented.toarray()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this alerady returs a dense matrix

b, G_sparse_reconstructed.sum(0), rtol=1e-5, atol=1e-7
)
else:
np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep just do that

cols_aug.append(j)
data_aug.append(C[i, j])

C_augmented = coo_matrix(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use nx.from_numpy here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants