Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Sparse implementation of EMD #683

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ enum ProblemType {

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint32_t *iD, uint32_t *jD, double *D, uint64_t nD, uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG, double *alpha, double *beta, double *cost, uint64_t maxIter);



Expand Down
79 changes: 79 additions & 0 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,82 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,

return ret;
}



int EMD_wrap_sparse(int n1, int n2, double *X, double *Y,
uint32_t *iD, uint32_t *jD, double *D, uint64_t nD,
uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG,
double *alpha, double *beta, double *cost, uint64_t maxIter)
{
// beware M and C are stored in row major C style!!!

using namespace lemon;
uint64_t n, m, cur;

typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);

n = n1;
m = n2;


std::vector<double> weights2(m);
Digraph di(n, m);
NetworkSimplexSimple<Digraph, double, double, node_id_type> net(di, true, n + m, n*m, maxIter);

// Set supply and demand, don't account for 0 values (faster)

// Demand is actually negative supply...

for (uint64_t i = 0; i < n2; i++)
{
double val = *(Y + i);
if (val > 0)
{
weights2[i] = -val;
}
}

// Define the graph
net.supplyMap(X, n, &weights2[0], m);

// Set the cost of each edge
for (uint64_t k = 0; k < nD; k++)
{
int i = iD[k];
int j = jD[k];
net.setCost(di.arcFromId(i * m + j), D[k]);
}

// Solve the problem with the network simplex algorithm

int ret = net.run();
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED)
{
*cost = net.totalCost();
Arc a;
di.first(a);
cur = 0;
for (; a != INVALID; di.next(a))
{
int i = di.source(a);
int j = di.target(a);
double flow = net.flow(a);
if (flow > 0)
{

*(G + cur) = flow;
*(iG + cur) = i;
*(jG + cur) = j - n;
*(alpha + i) = -net.potential(i);
*(beta + j - n) = net.potential(j);
cur++;
}
}
*nG = cur; // nb of value +1 for numpy indexing
}

return ret;

}
80 changes: 79 additions & 1 deletion ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from ..utils import dist

cimport cython
cimport libc.math as math
from libc.stdint cimport uint64_t
from libc.stdint cimport uint64_t, uint32_t

import warnings

Expand All @@ -23,6 +23,7 @@ cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint32_t *iD, uint32_t *jD, double *D, uint64_t nD, uint32_t *iG, uint32_t *jG, double *G, uint64_t *nG, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil


def check_result(result_code):
Expand Down Expand Up @@ -117,6 +118,83 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code

@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[int, ndim=1, mode="c"] iM, np.ndarray[int, ndim=1, mode="c"] jM, np.ndarray[double, ndim=1, mode="c"] M, uint64_t max_iter):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix

gamm=emd(a,b,M)

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma\geq 0
where :

- M is the metric cost matrix
- a and b are the sample weights

.. warning::
Note that the M matrix needs to be a C-order :py.cls:`numpy.array`

.. warning::
The C++ solver discards all samples in the distributions with
zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
on the samples with weights different from zero.

Parameters
----------
a : (ns,) numpy.ndarray, float64
source histogram
b : (nt,) numpy.ndarray, float64
target histogram
iM : (n,) numpy.ndarray, uint32
row indices of the non zero elements of the loss matrix (COO)
jM : (n,) numpy.ndarray, uint32
column indices of the non zero elements of the loss matrix (COO)
M : (n,) numpy.ndarray, float64
loss matrix (COO)
max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.

Returns
-------
gamma: (ns x nt) numpy.ndarray
Optimal transportation matrix for the given parameters

"""
cdef int n1= a.shape[0]
cdef int n2= b.shape[0]
cdef int nmax=n1+n2-1
cdef int result_code = 0
cdef uint64_t nG=0
cdef uint64_t maxiter = max_iter

cdef double cost=0
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)

cdef np.ndarray[double, ndim=1, mode="c"] G=np.zeros(nmax)
cdef np.ndarray[uint32_t, ndim=1, mode="c"] iG=np.zeros(nmax, dtype=np.uint32)
cdef np.ndarray[uint32_t, ndim=1, mode="c"] jG=np.zeros(nmax, dtype=np.uint32)

with nogil:
result_code = EMD_wrap_sparse(n1, n2, <double*> a.data, <double*> b.data,
<uint32_t*> iM.data, <uint32_t*> jM.data, <double*> M.data, iM.shape[0],
<uint32_t*> iG.data, <uint32_t*> jG.data, <double*> G.data, &nG,
<double*> alpha.data, <double*> beta.data, <double*> &cost, maxiter)
return G[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code





@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down
Loading