@@ -244,6 +244,43 @@ function sinkhorn2(μ, ν, C, ε; kwargs...)
244244 return pot. sinkhorn2 (μ, ν, PyCall. PyReverseDims (permutedims (C)), ε; kwargs... )
245245end
246246
247+ """
248+ empirical_sinkhorn_divergence(xsource, xtarget, ε; kwargs...)
249+
250+ Compute the Sinkhorn divergence from empirical data, where `xsource` and `xtarget` are
251+ arrays representing samples in the source domain and target domain, respectively, and `ε`
252+ is the regularization term.
253+
254+ This function is a wrapper of the function
255+ [`ot.bregman.empirical_sinkhorn_divergence`](https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.empirical_sinkhorn_divergence)
256+ in the Python Optimal Transport package. Keyword arguments are listed in the documentation of the Python function.
257+
258+ # Examples
259+
260+ ```jldoctest
261+ julia> xsource = [1];
262+
263+ julia> xtarget = [2, 3];
264+
265+ julia> ε = 0.01;
266+
267+ julia> empirical_sinkhorn_divergence(xsource, xtarget, ε) ≈
268+ sinkhorn2([1], [0.5, 0.5], [1 4], ε) -
269+ (
270+ sinkhorn2([1], [1], zeros(1, 1), ε) +
271+ sinkhorn2([0.5, 0.5], [0.5, 0.5], [0 1; 1 0], ε)
272+ ) / 2
273+ true
274+ ```
275+
276+ See also: [`sinkhorn2`](@ref)
277+ """
278+ function empirical_sinkhorn_divergence (xsource, xtarget, ε; kwargs... )
279+ return pot. bregman. empirical_sinkhorn_divergence (
280+ reshape (xsource, Val (2 )), reshape (xtarget, Val (2 )), ε; kwargs...
281+ )
282+ end
283+
247284"""
248285 sinkhorn_unbalanced(μ, ν, C, ε, λ; kwargs...)
249286
0 commit comments