diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index b56f0601b..2c8160f38 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -31,6 +31,10 @@ enum ProblemType {
 
 int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
 int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
+// Sparse variant: the cost is passed as nD COO triplets (iD, jD, D) and the
+// plan is returned as *nG COO triplets (iG, jG, G); the output arrays must be
+// allocated by the caller.
+int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint32_t *iD, uint32_t *jD, double *D, uint64_t nD, uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG, double *alpha, double *beta, double *cost, uint64_t maxIter);
 
 
 
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 4aa5a6e72..3634fa04c 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -216,3 +216,109 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
 
     return ret;
 }
+
+
+/**
+ * Solve an EMD problem whose cost matrix is given as COO triplets.
+ *
+ * The transport graph is still the full bipartite n1 x n2 graph; only the
+ * arcs listed in (iD, jD, D) get an explicit cost.
+ * NOTE(review): arcs that are not listed keep the solver's initial cost --
+ * confirm that this matches the intended value of the implicit entries.
+ *
+ * @param n1, n2      number of source / target samples
+ * @param X, Y        source weights (n1 values) and target weights (n2 values)
+ * @param iD, jD, D   row index, column index and cost of each of the nD entries;
+ *                    requires iD[k] < n1 and jD[k] < n2
+ * @param nD          number of cost entries
+ * @param iG, jG, G   output: row index, column index and mass of every positive
+ *                    flow. Written without a capacity check, so the caller must
+ *                    provide enough room (emd_c_sparse allocates n1 + n2 - 1)
+ * @param nG          output: number of triplets written to (iG, jG, G)
+ * @param alpha, beta output: dual potentials; only the entries of nodes that
+ *                    carry a positive flow are written
+ * @param cost        output: total cost of the returned plan
+ * @param maxIter     iteration limit handed to the network simplex
+ * @return solver status as int, or INFEASIBLE when a cost index is out of range
+ */
+int EMD_wrap_sparse(int n1, int n2, double *X, double *Y,
+                    uint32_t *iD, uint32_t *jD, double *D, uint64_t nD,
+                    uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG,
+                    double *alpha, double *beta, double *cost, uint64_t maxIter)
+{
+    const uint64_t n = n1;
+    const uint64_t m = n2;
+
+    // Keep *nG defined on every return path, including solver failures.
+    *nG = 0;
+
+    // An index outside the n x m grid does not name an arc of the graph; bail
+    // out before it can be turned into an arc id. Done ahead of the lemon
+    // using-directive so INFEASIBLE is the ProblemType value from EMD.h.
+    for (uint64_t k = 0; k < nD; k++)
+    {
+        if (iD[k] >= n || jD[k] >= m)
+        {
+            return (int)INFEASIBLE;
+        }
+    }
+
+    using namespace lemon;
+
+    typedef FullBipartiteDigraph Digraph;
+    DIGRAPH_TYPEDEFS(Digraph);
+
+    // Demand is negative supply; non-positive target weights stay at zero.
+    std::vector<double> weights2(m);
+    for (uint64_t j = 0; j < m; j++)
+    {
+        const double val = Y[j];
+        if (val > 0)
+        {
+            weights2[j] = -val;
+        }
+    }
+
+    Digraph di(n, m);
+    // NOTE(review): template arguments are assumed to mirror the solver
+    // instantiation used by the dense EMD_wrap -- confirm against it.
+    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(di, true, n + m, n * m, maxIter);
+    net.supplyMap(X, n, weights2.data(), m);
+
+    // Arc (i, j) of the full bipartite graph has the row-major id i * m + j.
+    for (uint64_t k = 0; k < nD; k++)
+    {
+        const uint64_t i = iD[k];
+        const uint64_t j = jD[k];
+        net.setCost(di.arcFromId(i * m + j), D[k]);
+    }
+
+    // Solve the problem with the network simplex algorithm
+    int ret = net.run();
+    if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED)
+    {
+        *cost = net.totalCost();
+
+        uint64_t cur = 0;
+        Arc a;
+        for (di.first(a); a != INVALID; di.next(a))
+        {
+            const double flow = net.flow(a);
+            if (flow > 0)
+            {
+                const int i = di.source(a);
+                const int j = di.target(a); // target ids are offset by n
+
+                G[cur] = flow;
+                iG[cur] = i;
+                jG[cur] = j - n;
+                alpha[i] = -net.potential(i);
+                beta[j - n] = net.potential(j);
+                cur++;
+            }
+        }
+        *nG = cur; // number of triplets written
+    }
+
+    return ret;
+}
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 53df54fc3..4933b5b15 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,7 +14,7 @@ from ..utils import dist
 
 cimport cython
 cimport libc.math as math
-from libc.stdint cimport uint64_t
+from libc.stdint cimport uint64_t, uint32_t
 import warnings
 
 
@@ -23,6 +23,7 @@ cdef extern from "EMD.h":
     int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
     int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
     cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
+    int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint32_t *iD, uint32_t *jD, double *D, uint64_t nD, uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
 
 
 def check_result(result_code):
@@ -117,6 +118,106 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
             result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads)
     return G, cost, alpha, beta, result_code
 
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[int, ndim=1, mode="c"] iM, np.ndarray[int, ndim=1, mode="c"] jM, np.ndarray[double, ndim=1, mode="c"] M, uint64_t max_iter):
+    r"""
+    Solves the Earth Movers distance problem for a cost matrix given as COO
+    triplets and returns the optimal transport plan as COO triplets
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+    where :
+
+    - M is the metric cost matrix
+    - a and b are the sample weights
+
+    .. warning::
+        All arrays need to be C-contiguous :py:class:`numpy.ndarray` of the
+        dtype listed below; iM and jM are reinterpreted as unsigned 32 bit
+        indices by the C++ solver, which assumes a 4 byte C int.
+
+    .. warning::
+        The dual potentials are only filled in for the samples that carry
+        mass in the returned plan; all other entries are left at zero.
+
+    Parameters
+    ----------
+    a : (ns,) numpy.ndarray, float64
+        source histogram
+    b : (nt,) numpy.ndarray, float64
+        target histogram
+    iM : (n,) numpy.ndarray, intc
+        row indices of the stored elements of the loss matrix (COO), in [0, ns)
+    jM : (n,) numpy.ndarray, intc
+        column indices of the stored elements of the loss matrix (COO), in [0, nt)
+    M : (n,) numpy.ndarray, float64
+        values of the stored elements of the loss matrix (COO)
+    max_iter : uint64_t
+        The maximum number of iterations before stopping the optimization
+        algorithm if it has not converged.
+
+    Returns
+    -------
+    G : (nG,) numpy.ndarray, float64
+        values of the non zero elements of the transport plan (COO)
+    iG : (nG,) numpy.ndarray, uint32
+        row indices of the non zero elements of the transport plan
+    jG : (nG,) numpy.ndarray, uint32
+        column indices of the non zero elements of the transport plan
+    cost : float
+        total cost reported by the solver
+    alpha : (ns,) numpy.ndarray, float64
+        dual potentials of the source samples
+    beta : (nt,) numpy.ndarray, float64
+        dual potentials of the target samples
+    result_code : int
+        status returned by the solver (a ProblemType value)
+
+    Raises
+    ------
+    ValueError
+        If iM, jM and M differ in length or an index lies outside the
+        (ns, nt) loss matrix.
+
+    """
+    cdef int n1 = a.shape[0]
+    cdef int n2 = b.shape[0]
+    cdef int nmax = n1 + n2 - 1
+    cdef int result_code = 0
+    cdef uint64_t nM = iM.shape[0]
+    cdef uint64_t nG = 0
+    cdef uint64_t maxiter = max_iter
+
+    cdef double cost = 0
+    cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
+    cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
+
+    # The solver writes one triplet per positive flow; n1 + n2 - 1 slots are reserved for them.
+    cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros(nmax)
+    cdef np.ndarray[uint32_t, ndim=1, mode="c"] iG = np.zeros(nmax, dtype=np.uint32)
+    cdef np.ndarray[uint32_t, ndim=1, mode="c"] jG = np.zeros(nmax, dtype=np.uint32)
+
+    # The C++ side reads nM entries from all three arrays and trusts the indices.
+    if jM.shape[0] != iM.shape[0] or M.shape[0] != iM.shape[0]:
+        raise ValueError("iM, jM and M must have the same length")
+    if nM > 0 and (iM.min() < 0 or iM.max() >= n1 or jM.min() < 0 or jM.max() >= n2):
+        raise ValueError("iM and jM must index into the (ns, nt) loss matrix")
+
+    with nogil:
+        result_code = EMD_wrap_sparse(n1, n2, <double*> a.data, <double*> b.data,
+                                      <uint32_t*> iM.data, <uint32_t*> jM.data, <double*> M.data, nM,
+                                      <uint32_t*> iG.data, <uint32_t*> jG.data, <double*> G.data, &nG,
+                                      <double*> alpha.data, <double*> beta.data, &cost, maxiter)
+    return G[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code
+
 
 @cython.boundscheck(False)
 @cython.wraparound(False)