Skip to content

Commit

Permalink
move plot to separate utils file
Browse files Browse the repository at this point in the history
  • Loading branch information
vloison committed Jul 4, 2024
1 parent a212ab1 commit a7c924e
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 133 deletions.
3 changes: 2 additions & 1 deletion experiments/meg/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import copy
import torch
from fadin.solver import FaDIn, plot
from fadin.solver import FaDIn
from fadin.utils.vis import plot
from fadin.utils.utils_meg import proprocess_tasks, filter_activation, \
get_atoms_timestamps

Expand Down
6 changes: 3 additions & 3 deletions fadin/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def init_hawkes_params(solver, init, events, n_ground_events, end_time):
kernel_params_init = init['kernel']

# Format initial parameters for optimization
baseline = (baseline * solver.baseline_mask).requires_grad_(True)
alpha = (alpha * solver.alpha_mask).requires_grad_(True)
params_intens = [baseline, alpha]
solver.baseline = (baseline * solver.baseline_mask).requires_grad_(True)
solver.alpha = (alpha * solver.alpha_mask).requires_grad_(True)
params_intens = [solver.baseline, solver.alpha]
solver.n_kernel_params = len(kernel_params_init)
for i in range(solver.n_kernel_params):
kernel_param = kernel_params_init[i].float().clip(1e-4)
Expand Down
129 changes: 0 additions & 129 deletions fadin/solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
import time
import matplotlib.pyplot as plt
import numpy as np

from fadin.utils.utils import optimizer, projected_grid
from fadin.utils.compute_constants import compute_constants_fadin
Expand Down Expand Up @@ -316,130 +314,3 @@ def fit(self, events, end_time):
print('iterations in ', time.time() - start)

return self


def plot(solver, plotfig=False, bl_noise=False, title=None, ch_names=None,
savefig=None):
"""
Plots estimated kernels and baselines of solver.
Should be called after calling the `fit` method on solver.
Parameters
----------
solver: |`MarkedFaDin` or `FaDIn` solver.
`fit` method should be called on the solver before calling `plot`.
plotfig: bool (default `False`)
If set to `True`, the figure is plotted.
bl_noise: bool (default`False`)
Whether to plot the baseline of noisy activations.
Only works if the solver has 'baseline_noise' attribute.
title: `str` or `None`, default=`None`
Title of the plot. If set to `None`, the title text is generic.
ch_names: list of `str` (default `None`)
Channel names for subplots. If set to `None`, will be set to
`np.arange(solver.n_dim).astype('str')`.
savefig: str or `None`, default=`None`
Path for saving the figure. If set to `None`, the figure is not saved.
Returns
-------
fig, axs : matplotlib.pyplot Figure
n_dim x n_dim subplots, where subplot of coordinates (i, j) shows the
kernel component $\\alpha_{i, j}\\phi_{i, j}$ and the baseline $\\mu_i$
of the intensity function $\\lambda_i$.
"""
# Recover kernel time values and y values for kernel plot
discretization = torch.linspace(0, solver.kernel_length, 200)
kernel = DiscreteKernelFiniteSupport(solver.delta,
solver.n_dim,
kernel=solver.kernel,
kernel_length=solver.kernel_length)

kappa_values = kernel.kernel_eval(solver.params_intens[-2:],
discretization).detach()
# Plot
if ch_names is None:
ch_names = np.arange(solver.n_dim).astype('str')
fig, axs = plt.subplots(nrows=solver.n_dim,
ncols=solver.n_dim,
figsize=(4 * solver.n_dim, 4 * solver.n_dim),
sharey=True,
sharex=True,
squeeze=False)
for i in range(solver.n_dim):
for j in range(solver.n_dim):
# Plot baseline
label = (rf'$\mu_{{{ch_names[i]}}}$=' +
f'{round(solver.baseline[i].item(), 2)}')
axs[i, j].hlines(
y=solver.baseline[i].item(),
xmin=0,
xmax=solver.kernel_length,
label=label,
color='orange',
linewidth=4
)
if bl_noise:
# Plot noise baseline
mutilde = round(solver.baseline_noise[i].item(), 2)
label = rf'$\tilde{{\mu}}_{{{ch_names[i]}}}$={mutilde}'
axs[i, j].hlines(
y=solver.baseline_noise[i].item(),
xmin=0,
xmax=solver.kernel_length,
label=label,
color='green',
linewidth=4
)
# Plot kernel (i, j)
phi_values = solver.alpha[i, j].item() * kappa_values[i, j, 1:]
axs[i, j].plot(
discretization[1:],
phi_values,
label=rf'$\phi_{{{ch_names[i]},{ch_names[j]}}}$',
linewidth=4
)
if solver.kernel == 'truncated_gaussian':
# Plot mean of gaussian kernel
mean = round(solver.params_intens[-2][i, j].item(), 2)
axs[i, j].vlines(
x=mean,
ymin=0,
ymax=torch.max(phi_values).item(),
label=rf'mean={mean}',
color='pink',
linestyles='dashed',
linewidth=3,
)
# Handle text
axs[i, j].set_xlabel('Time', size='x-large')
axs[i, j].tick_params(
axis='both',
which='major',
labelsize='x-large'
)
axs[i, j].set_title(
f'{ch_names[j]}-> {ch_names[i]}',
size='x-large'
)
axs[i, j].legend(fontsize='large', loc='best')
# Plot title
if title is None:
fig_title = 'Hawkes influence ' + solver.kernel + ' kernel'
else:
fig_title = title
fig.suptitle(fig_title, size=20)
fig.tight_layout()
# Save figure
if savefig is not None:
fig.savefig(savefig)
# Plot figure
if plotfig:
fig.show()

return fig, axs
132 changes: 132 additions & 0 deletions fadin/utils/vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import torch
import numpy as np
import matplotlib.pyplot as plt

from fadin.kernels import DiscreteKernelFiniteSupport


def plot(solver, plotfig=False, bl_noise=False, title=None, ch_names=None,
savefig=None):
"""
Plots estimated kernels and baselines of `FaDIn` solver.
Should be called after calling the `fit` method on solver.
Parameters
----------
solver: `FaDIn` solver.
`fit` method should be called on the solver before calling `plot`.
plotfig: bool (default `False`)
If set to `True`, the figure is plotted.
bl_noise: bool (default`False`)
Whether to plot the baseline of noisy activations.
Only works if the solver has 'baseline_noise' attribute.
title: `str` or `None`, default=`None`
Title of the plot. If set to `None`, the title text is generic.
ch_names: list of `str` (default `None`)
Channel names for subplots. If set to `None`, will be set to
`np.arange(solver.n_dim).astype('str')`.
savefig: str or `None`, default=`None`
Path for saving the figure. If set to `None`, the figure is not saved.
Returns
-------
fig, axs : matplotlib.pyplot Figure
n_dim x n_dim subplots, where subplot of coordinates (i, j) shows the
kernel component $\\alpha_{i, j}\\phi_{i, j}$ and the baseline $\\mu_i$
of the intensity function $\\lambda_i$.
"""
# Recover kernel time values and y values for kernel plot
discretization = torch.linspace(0, solver.kernel_length, 200)
kernel = DiscreteKernelFiniteSupport(solver.delta,
solver.n_dim,
kernel=solver.kernel,
kernel_length=solver.kernel_length)

kappa_values = kernel.kernel_eval(solver.params_intens[-2:],
discretization).detach()
# Plot
if ch_names is None:
ch_names = np.arange(solver.n_dim).astype('str')
fig, axs = plt.subplots(nrows=solver.n_dim,
ncols=solver.n_dim,
figsize=(4 * solver.n_dim, 4 * solver.n_dim),
sharey=True,
sharex=True,
squeeze=False)
for i in range(solver.n_dim):
for j in range(solver.n_dim):
# Plot baseline
label = (rf'$\mu_{{{ch_names[i]}}}$=' +
f'{round(solver.baseline[i].item(), 2)}')
axs[i, j].hlines(
y=solver.baseline[i].item(),
xmin=0,
xmax=solver.kernel_length,
label=label,
color='orange',
linewidth=4
)
if bl_noise:
# Plot noise baseline
mutilde = round(solver.baseline_noise[i].item(), 2)
label = rf'$\tilde{{\mu}}_{{{ch_names[i]}}}$={mutilde}'
axs[i, j].hlines(
y=solver.baseline_noise[i].item(),
xmin=0,
xmax=solver.kernel_length,
label=label,
color='green',
linewidth=4
)
# Plot kernel (i, j)
phi_values = solver.alpha[i, j].item() * kappa_values[i, j, 1:]
axs[i, j].plot(
discretization[1:],
phi_values,
label=rf'$\phi_{{{ch_names[i]},{ch_names[j]}}}$',
linewidth=4
)
if solver.kernel == 'truncated_gaussian':
# Plot mean of gaussian kernel
mean = round(solver.params_intens[-2][i, j].item(), 2)
axs[i, j].vlines(
x=mean,
ymin=0,
ymax=torch.max(phi_values).item(),
label=rf'mean={mean}',
color='pink',
linestyles='dashed',
linewidth=3,
)
# Handle text
axs[i, j].set_xlabel('Time', size='x-large')
axs[i, j].tick_params(
axis='both',
which='major',
labelsize='x-large'
)
axs[i, j].set_title(
f'{ch_names[j]}-> {ch_names[i]}',
size='x-large'
)
axs[i, j].legend(fontsize='large', loc='best')
# Plot title
if title is None:
fig_title = 'Hawkes influence ' + solver.kernel + ' kernel'
else:
fig_title = title
fig.suptitle(fig_title, size=20)
fig.tight_layout()
# Save figure
if savefig is not None:
fig.savefig(savefig)
# Plot figure
if plotfig:
fig.show()

return fig, axs

0 comments on commit a7c924e

Please sign in to comment.