Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions ot/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import scipy as sp
from scipy.stats import ortho_group, multivariate_normal
from .utils import check_random_state, deprecated


Expand Down Expand Up @@ -180,3 +181,97 @@ def make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **k
def get_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs):
    """Deprecated alias of ``make_data_classif``; use that function instead.

    All arguments are forwarded unchanged to ``make_data_classif``.

    Parameters
    ----------
    dataset : object
        forwarded as the ``dataset`` argument of ``make_data_classif``
    n : object
        forwarded as the ``n`` argument of ``make_data_classif``
    nz : float
        forwarded as the ``nz`` argument of ``make_data_classif``
    theta : float
        forwarded as the ``theta`` argument of ``make_data_classif``
    random_state : object
        forwarded as the ``random_state`` argument of ``make_data_classif``
    **kwargs
        extra keyword arguments forwarded to ``make_data_classif``

    Returns
    -------
    Whatever ``make_data_classif`` returns for the same arguments.
    """
    # Forward the caller's values: the previous implementation passed the
    # literal defaults (nz=0.5, theta=0, random_state=None) and therefore
    # silently ignored whatever the caller supplied.
    return make_data_classif(
        dataset, n, nz=nz, theta=theta, random_state=random_state, **kwargs
    )


def make_gauss_hd(
    ns,
    nt,
    p=100,
    dim=5,
    m_diff=3,
    a=(10, 15),
    b=(3, 3),
    sub_the_same=False,
    random_state=None,
):
    r"""Generate source and target domains from Gaussian HD distributions.

    Each domain is a `p`-dimensional Gaussian whose covariance is
    ``Q @ diag(a, ..., a, b, ..., b) @ Q.T`` with ``Q`` a random orthogonal
    matrix: variance `a` on the first `d` directions of ``Q`` (the principal
    subspace) and variance `b` on the remaining ``p - d`` directions.

    Parameters
    ----------
    ns : int
        number of samples (source)
    nt : int
        number of samples (target)
    p : int
        dimension of the ambient space the data live in
    dim : (int, int) or int
        the intrinsic dimensions of the source and target Gaussian HD
        distributions. If a single int, the intrinsic dimension is assumed to
        be the same for both
    m_diff : float
        the shift in the first coordinate of the mean of the target
        distribution; the source mean is the zero vector
    a : (float, float)
        positive floating numbers corresponding to the isotropic variances in
        the principal subspace, for the source and target distributions,
        respectively. The same as \delta in
        :ref:`[1] <references-make_gauss_hd>`, Proposition 2.2
    b : (float, float)
        positive floating numbers corresponding to the isotropic variance
        outside the principal subspace for the source and target
        distributions, respectively
    sub_the_same : bool
        should the source/target Gaussian HD distributions live in the same
        principal subspace?
    random_state : int, RandomState instance or None, optional
        seed or random number generator handed to the scipy samplers. With
        ``None`` (default) the scipy samplers are called with
        ``random_state=None``, i.e. exactly as when the argument is omitted

    Returns
    -------
    Xs : ndarray, shape (ns, p)
        `ns` observations of size `p` (source)
    Xt : ndarray, shape (nt, p)
        `nt` observations of size `p` (target)
    prmts : dict
        the parameters of the two Gaussian HD distributions:

        - ``"ms"``, ``"mt"`` : mean vectors, shape (p,)
        - ``"sigma2_s"``, ``"sigma2_t"`` : 0-d arrays holding ``b[0]``, ``b[1]``
        - ``"ls"``, ``"lt"`` : ``a - b`` repeated over the intrinsic dimension
        - ``"Us"``, ``"Ut"`` : first `d` columns of the orthogonal bases
        - ``"ds"``, ``"dt"`` : intrinsic dimensions as 1-element arrays
        - ``"Cs"``, ``"Ct"`` : full (p, p) covariance matrices


    .. _references-make_gauss_hd:
    References
    ----------

    .. [1] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to
        High-Dimensional Gaussian Distributions")

    """
    # A scalar `dim` (Python int or NumPy integer scalar) is shared by both
    # domains; anything else is read as a (source, target) pair.
    d = (dim, dim) if np.ndim(dim) == 0 else tuple(dim)

    # Turn an integer seed into ONE generator so that the successive draws
    # below consume a single stream instead of each re-seeding from scratch.
    rng = random_state
    if isinstance(rng, (int, np.integer)):
        rng = np.random.RandomState(rng)

    mu = np.zeros((2, p))
    mu[1, 0] = m_diff

    # Both bases are always drawn, so the amount of randomness consumed does
    # not depend on `sub_the_same`.
    Q = [ortho_group.rvs(p, random_state=rng) for _ in range(2)]
    if sub_the_same:
        Q[1] = Q[0]

    S = []
    for k in range(2):
        spectrum = np.hstack((np.full(d[k], a[k]), np.full(p - d[k], b[k])))
        S.append(Q[k] @ np.diag(spectrum) @ Q[k].T)

    # scipy squeezes singleton axes out of its result, so reshape explicitly
    # to honour the documented (n, p) shape even when n == 1.
    Xs = np.reshape(
        multivariate_normal.rvs(mean=mu[0], cov=S[0], size=ns, random_state=rng),
        (ns, p),
    )
    # The target sample must contain `nt` rows (it previously used `ns`).
    Xt = np.reshape(
        multivariate_normal.rvs(mean=mu[1], cov=S[1], size=nt, random_state=rng),
        (nt, p),
    )

    ds, dt = d
    sigma2_s = np.array(b[0])
    sigma2_t = np.array(b[1])

    prmts = {
        "ms": mu[0],
        "mt": mu[1],
        "sigma2_s": sigma2_s,
        "sigma2_t": sigma2_t,
        "ls": np.repeat(a[0], ds) - sigma2_s,
        "lt": np.repeat(a[1], dt) - sigma2_t,
        "Us": Q[0][:, :ds],
        "Ut": Q[1][:, :dt],
        "ds": np.array([ds]),
        "dt": np.array([dt]),
        "Cs": S[0],
        "Ct": S[1],
    }

    return Xs, Xt, prmts
Loading
Loading