diff --git a/experiments/meg/sample.py b/experiments/meg/sample.py index 679fb87..d7a3cda 100644 --- a/experiments/meg/sample.py +++ b/experiments/meg/sample.py @@ -2,7 +2,8 @@ import numpy as np import copy import torch -from fadin.solver import FaDIn, plot +from fadin.solver import FaDIn +from fadin.utils.vis import plot from fadin.utils.utils_meg import proprocess_tasks, filter_activation, \ get_atoms_timestamps diff --git a/fadin/init.py b/fadin/init.py index 5d1c218..0984260 100644 --- a/fadin/init.py +++ b/fadin/init.py @@ -59,9 +59,9 @@ def init_hawkes_params(solver, init, events, n_ground_events, end_time): kernel_params_init = init['kernel'] # Format initial parameters for optimization - baseline = (baseline * solver.baseline_mask).requires_grad_(True) - alpha = (alpha * solver.alpha_mask).requires_grad_(True) - params_intens = [baseline, alpha] + solver.baseline = (baseline * solver.baseline_mask).requires_grad_(True) + solver.alpha = (alpha * solver.alpha_mask).requires_grad_(True) + params_intens = [solver.baseline, solver.alpha] solver.n_kernel_params = len(kernel_params_init) for i in range(solver.n_kernel_params): kernel_param = kernel_params_init[i].float().clip(1e-4) diff --git a/fadin/solver.py b/fadin/solver.py index 8ede96e..687f387 100644 --- a/fadin/solver.py +++ b/fadin/solver.py @@ -1,7 +1,5 @@ import torch import time -import matplotlib.pyplot as plt -import numpy as np from fadin.utils.utils import optimizer, projected_grid from fadin.utils.compute_constants import compute_constants_fadin @@ -316,130 +314,3 @@ def fit(self, events, end_time): print('iterations in ', time.time() - start) return self - - -def plot(solver, plotfig=False, bl_noise=False, title=None, ch_names=None, - savefig=None): - """ - Plots estimated kernels and baselines of solver. - Should be called after calling the `fit` method on solver. - - Parameters - ---------- - solver: |`MarkedFaDin` or `FaDIn` solver. - `fit` method should be called on the solver before calling `plot`. - - plotfig: bool (default `False`) - If set to `True`, the figure is plotted. - - bl_noise: bool (default`False`) - Whether to plot the baseline of noisy activations. - Only works if the solver has 'baseline_noise' attribute. - - title: `str` or `None`, default=`None` - Title of the plot. If set to `None`, the title text is generic. - - ch_names: list of `str` (default `None`) - Channel names for subplots. If set to `None`, will be set to - `np.arange(solver.n_dim).astype('str')`. - savefig: str or `None`, default=`None` - Path for saving the figure. If set to `None`, the figure is not saved. - - Returns - ------- - fig, axs : matplotlib.pyplot Figure - n_dim x n_dim subplots, where subplot of coordinates (i, j) shows the - kernel component $\\alpha_{i, j}\\phi_{i, j}$ and the baseline $\\mu_i$ - of the intensity function $\\lambda_i$. - - """ - # Recover kernel time values and y values for kernel plot - discretization = torch.linspace(0, solver.kernel_length, 200) - kernel = DiscreteKernelFiniteSupport(solver.delta, - solver.n_dim, - kernel=solver.kernel, - kernel_length=solver.kernel_length) - - kappa_values = kernel.kernel_eval(solver.params_intens[-2:], - discretization).detach() - # Plot - if ch_names is None: - ch_names = np.arange(solver.n_dim).astype('str') - fig, axs = plt.subplots(nrows=solver.n_dim, - ncols=solver.n_dim, - figsize=(4 * solver.n_dim, 4 * solver.n_dim), - sharey=True, - sharex=True, - squeeze=False) - for i in range(solver.n_dim): - for j in range(solver.n_dim): - # Plot baseline - label = (rf'$\mu_{{{ch_names[i]}}}$=' + - f'{round(solver.baseline[i].item(), 2)}') - axs[i, j].hlines( - y=solver.baseline[i].item(), - xmin=0, - xmax=solver.kernel_length, - label=label, - color='orange', - linewidth=4 - ) - if bl_noise: - # Plot noise baseline - mutilde = round(solver.baseline_noise[i].item(), 2) - label = rf'$\tilde{{\mu}}_{{{ch_names[i]}}}$={mutilde}' - axs[i, j].hlines( - y=solver.baseline_noise[i].item(), - xmin=0, - xmax=solver.kernel_length, - label=label, - color='green', - linewidth=4 - ) - # Plot kernel (i, j) - phi_values = solver.alpha[i, j].item() * kappa_values[i, j, 1:] - axs[i, j].plot( - discretization[1:], - phi_values, - label=rf'$\phi_{{{ch_names[i]},{ch_names[j]}}}$', - linewidth=4 - ) - if solver.kernel == 'truncated_gaussian': - # Plot mean of gaussian kernel - mean = round(solver.params_intens[-2][i, j].item(), 2) - axs[i, j].vlines( - x=mean, - ymin=0, - ymax=torch.max(phi_values).item(), - label=rf'mean={mean}', - color='pink', - linestyles='dashed', - linewidth=3, - ) - # Handle text - axs[i, j].set_xlabel('Time', size='x-large') - axs[i, j].tick_params( - axis='both', - which='major', - labelsize='x-large' - ) - axs[i, j].set_title( - f'{ch_names[j]}-> {ch_names[i]}', - size='x-large' - ) - axs[i, j].legend(fontsize='large', loc='best') - # Plot title - if title is None: - fig_title = 'Hawkes influence ' + solver.kernel + ' kernel' - else: - fig_title = title - fig.suptitle(fig_title, size=20) - fig.tight_layout() - # Save figure - if savefig is not None: - fig.savefig(savefig) - # Plot figure - if plotfig: - fig.show() - - return fig, axs diff --git a/fadin/utils/vis.py b/fadin/utils/vis.py new file mode 100644 index 0000000..07627bf --- /dev/null +++ b/fadin/utils/vis.py @@ -0,0 +1,132 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from fadin.kernels import DiscreteKernelFiniteSupport + + +def plot(solver, plotfig=False, bl_noise=False, title=None, ch_names=None, + savefig=None): + """ + Plots estimated kernels and baselines of `FaDIn` solver. + Should be called after calling the `fit` method on solver. + + Parameters + ---------- + solver: `FaDIn` solver. + `fit` method should be called on the solver before calling `plot`. + + plotfig: bool (default `False`) + If set to `True`, the figure is plotted. + + bl_noise: bool (default`False`) + Whether to plot the baseline of noisy activations. + Only works if the solver has 'baseline_noise' attribute. + + title: `str` or `None`, default=`None` + Title of the plot. If set to `None`, the title text is generic. + + ch_names: list of `str` (default `None`) + Channel names for subplots. If set to `None`, will be set to + `np.arange(solver.n_dim).astype('str')`. + savefig: str or `None`, default=`None` + Path for saving the figure. If set to `None`, the figure is not saved. + + Returns + ------- + fig, axs : matplotlib.pyplot Figure + n_dim x n_dim subplots, where subplot of coordinates (i, j) shows the + kernel component $\\alpha_{i, j}\\phi_{i, j}$ and the baseline $\\mu_i$ + of the intensity function $\\lambda_i$. + + """ + # Recover kernel time values and y values for kernel plot + discretization = torch.linspace(0, solver.kernel_length, 200) + kernel = DiscreteKernelFiniteSupport(solver.delta, + solver.n_dim, + kernel=solver.kernel, + kernel_length=solver.kernel_length) + + kappa_values = kernel.kernel_eval(solver.params_intens[-2:], + discretization).detach() + # Plot + if ch_names is None: + ch_names = np.arange(solver.n_dim).astype('str') + fig, axs = plt.subplots(nrows=solver.n_dim, + ncols=solver.n_dim, + figsize=(4 * solver.n_dim, 4 * solver.n_dim), + sharey=True, + sharex=True, + squeeze=False) + for i in range(solver.n_dim): + for j in range(solver.n_dim): + # Plot baseline + label = (rf'$\mu_{{{ch_names[i]}}}$=' + + f'{round(solver.baseline[i].item(), 2)}') + axs[i, j].hlines( + y=solver.baseline[i].item(), + xmin=0, + xmax=solver.kernel_length, + label=label, + color='orange', + linewidth=4 + ) + if bl_noise: + # Plot noise baseline + mutilde = round(solver.baseline_noise[i].item(), 2) + label = rf'$\tilde{{\mu}}_{{{ch_names[i]}}}$={mutilde}' + axs[i, j].hlines( + y=solver.baseline_noise[i].item(), + xmin=0, + xmax=solver.kernel_length, + label=label, + color='green', + linewidth=4 + ) + # Plot kernel (i, j) + phi_values = solver.alpha[i, j].item() * kappa_values[i, j, 1:] + axs[i, j].plot( + discretization[1:], + phi_values, + label=rf'$\phi_{{{ch_names[i]},{ch_names[j]}}}$', + linewidth=4 + ) + if solver.kernel == 'truncated_gaussian': + # Plot mean of gaussian kernel + mean = round(solver.params_intens[-2][i, j].item(), 2) + axs[i, j].vlines( + x=mean, + ymin=0, + ymax=torch.max(phi_values).item(), + label=rf'mean={mean}', + color='pink', + linestyles='dashed', + linewidth=3, + ) + # Handle text + axs[i, j].set_xlabel('Time', size='x-large') + axs[i, j].tick_params( + axis='both', + which='major', + labelsize='x-large' + ) + axs[i, j].set_title( + f'{ch_names[j]}-> {ch_names[i]}', + size='x-large' + ) + axs[i, j].legend(fontsize='large', loc='best') + # Plot title + if title is None: + fig_title = 'Hawkes influence ' + solver.kernel + ' kernel' + else: + fig_title = title + fig.suptitle(fig_title, size=20) + fig.tight_layout() + # Save figure + if savefig is not None: + fig.savefig(savefig) + # Plot figure + if plotfig: + fig.show() + + return fig, axs