diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 865c1e71a..b51b4e1ff 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -91,7 +91,7 @@ def build_noisy_circular_graph( g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) else: @@ -107,7 +107,7 @@ def build_noisy_circular_graph( if i == N - 1: g.add_edge(i, 1) g.add_edge(N, 0) - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) else: @@ -157,7 +157,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): plt.subplot(3, 3, i + 1) g = X0[i] pos = nx.kamada_kawai_layout(g) - nx.draw( + nx.draw_networkx( g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), @@ -173,7 +173,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances -Cs = [shortest_path(nx.adjacency_matrix(x).todense()) for x in X0] +Cs = [shortest_path(nx.adjacency_matrix(x).toarray()) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] Ys = [ np.array([v for (k, v) in nx.get_node_attributes(x, "attr_name").items()]).reshape( @@ -199,7 +199,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% pos = nx.kamada_kawai_layout(bary) -nx.draw( +nx.draw_networkx( bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False ) plt.suptitle("Barycenter", fontsize=20) diff --git a/examples/gromov/plot_partial_fgw.py b/examples/gromov/plot_partial_fgw.py new file mode 100644 index 000000000..87489ee46 --- /dev/null +++ b/examples/gromov/plot_partial_fgw.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot partial FGW for subgraph matching +================================= + +This example illustrates the computation of partial (Fused) Gromov-Wasserstein +divergences for subgraph matching tasks [18, 29]. + +[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# %% load libraries +import numpy as np +import pylab as pl +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import ( + partial_gromov_wasserstein, + entropic_partial_gromov_wasserstein, + partial_fused_gromov_wasserstein, + entropic_partial_fused_gromov_wasserstein, +) +from ot import unif, dist +# %% Graph generation and visualization functions + + +def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0): + """Create a noisy circular graph""" + # create clean circle + np.random.seed(random_seed) + g = nx.Graph() + g.add_nodes_from(np.arange(n_clean + n_noise)) + for i in range(n_clean): + g.add_node(i, weight=math.sin(2 * i * math.pi / n_clean)) + if i == (n_clean - 1): + g.add_edge(i, 0) + else: + g.add_edge(i, i + 1) + # add nodes out of the circle as structure noise + if n_noise > 0: + noisy_nodes = np.random.choice(np.arange(n_clean), n_noise) + for i, j in enumerate(noisy_nodes): + g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean)) + g.add_edge(i + n_clean, j) + return g + + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap="viridis") + cpick.set_array([]) + val_map = {} + for k, v in nx.get_node_attributes(nx_graph, "weight").items(): + val_map[k] = cpick.to_rgba(v) + colors = [] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + + +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, +): + if pos is None: + pos = nx.kamada_kawai_layout(G) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + nx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=nodes_size[node], + alpha=1, + node_color=node_color, + ) + return pos + + +def draw_transp_colored( + G1, + C1, + G2, + C2, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + color_features=False, +): + if color_features: + nodes_color_part1 = graph_colors(G1, vmin=-1, vmax=1) + nodes_color_part2 = graph_colors(G2, vmin=-1, vmax=1) + else: + nodes_color_part1 = C1.shape[0] * ["C0"] + nodes_color_part2 = C2.shape[0] * ["C0"] + + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + ) + T_max = T.max() + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.8, + alpha=0.5 * T[k1, k2] / T_max, + color=nodes_color_part1[k1], + ) + return pos1, pos2 + + +# %% +############################################################################## +# Generate and visualize data +# ------------- + +# We build a clean circular graph that will be matched to a noisy circular graph. + +clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0) + +noisy_graph = build_noisy_circular_graph(n_clean=15, n_noise=5) + +graphs = [clean_graph, noisy_graph] +list_pos = [] +pl.figure(figsize=(6, 3)) +for i in range(2): + pl.subplot(1, 2, i + 1) + g = graphs[i] + if i == 0: + pl.title("clean graph", fontsize=16) + else: + pl.title("noisy graph", fontsize=16) + pos = nx.kamada_kawai_layout(g) + list_pos.append(pos) + nx.draw_networkx( + g, + pos=pos, + node_color=graph_colors(g, vmin=-1, vmax=1), + with_labels=False, + node_size=100, + ) +pl.show() + +############################################################################## +# Partial (Entropic) Gromov-Wasserstein computation and visualization +# ---------------------- + +# Adjacency matrices are compared using both exact and entropic partial GW +# discarding for now node features +Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs] +ps = [unif(C.shape[0]) for C in Cs] + +# provide an informative initialization for visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], m=m, G0=G0, symmetric=True, log=True +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], reg=0.01, m=m, G0=G0, symmetric=True, log=True +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_gw_dist"], 3), + np.round(logent["partial_gw_dist"], 3), +] +list_dist_str = ["pGW", "pGW_e"] +pl.figure(2, figsize=(10, 3)) +pl.clf() +for i in range(2): + pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + ) + +pl.tight_layout() +pl.show() + +############################################################################## +# Partial (Entropic) Fused Gromov-Wasserstein computation and visualization +# ---------------------- + +# Add now node features compared using pairwise euclidean distance +# to illustrate partial FGW computation with trade-off parameter alpha=0.5 +Ys = [ + np.array([v for (k, v) in nx.get_node_attributes(G, "weight").items()]).reshape( + -1, 1 + ) + for G in graphs +] +M = dist(Ys[0], Ys[1]) +# provide an informative initialization for visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + reg=0.01, + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_fgw_dist"], 3), + np.round(logent["partial_fgw_dist"], 3), +] +list_dist_str = ["pFGW", "pFGW_e"] + +pl.figure(3, figsize=(10, 3)) +pl.clf() +for i in range(2): + pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + color_features=True, + ) + +pl.tight_layout() +pl.show() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 5ccc197d6..23a5f96a2 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -5,7 +5,10 @@ ================================================== This example is designed to show how to use the Partial (Gromov-)Wasserstein -distance computation in POT. +distance computation in POT [29]. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. """ # Author: Laetitia Chapel diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index fdfbba951..5a069fdaf 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -1173,7 +1173,7 @@ def entropic_partial_gromov_wasserstein( Returns ------- - :math: `gamma` : (dim_a, dim_b) ndarray + :math: `gamma` : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1461,7 +1461,7 @@ def entropic_partial_fused_gromov_wasserstein( The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) @@ -1530,7 +1530,7 @@ def entropic_partial_fused_gromov_wasserstein( Returns ------- - :math: `gamma` : (dim_a, dim_b) ndarray + :math: `gamma` : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1693,7 +1693,7 @@ def entropic_partial_fused_gromov_wasserstein2( The function solves the following optimization problem: .. math:: - PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) diff --git a/ot/solvers.py b/ot/solvers.py index 96794d9cd..decf6177e 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1002,8 +1002,17 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1074,7 +1083,8 @@ def solve_gromov( else: # regularized OT if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed" + "semirelaxed", + "partial", ]: # Balanced regularized OT if reg_type.lower() in ["entropy"] and ( M is None or alpha == 1 @@ -1232,8 +1242,17 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1262,8 +1281,17 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial FGW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: