Skip to content

Commit

Permalink
Start porting from cbcbeat to fenics-beat
Browse files Browse the repository at this point in the history
  • Loading branch information
finsberg committed Jan 11, 2024
1 parent d8dd39f commit eec07f7
Show file tree
Hide file tree
Showing 8 changed files with 4,785 additions and 1,801 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ packages = find:
install_requires =
ap-features
cardiac-geometries
cbcbeat
click
fenics-beat
fenics-pulse
h5py
matplotlib
Expand Down
228 changes: 110 additions & 118 deletions src/simcardems/ep_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

import json
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING

import cbcbeat
import beat
import dolfin
import pulse

# import cbcbeat

try:
import ufl_legacy as ufl
except ImportError:
Expand All @@ -22,69 +26,109 @@

if TYPE_CHECKING:
from .models.em_model import BaseEMCoupling
from .models.cell_model import BaseCellModel

# from .models.cell_model import BaseCellModel

logger = utils.getLogger(__name__)


class CellModel(NamedTuple):
init_states: Dict[str, float] | Dict[int, Dict[str, float]]
parameters: Dict[str, float] | Dict[int, Dict[str, float]]
v_index: int | Dict[int, int]
fun: Callable | Dict[int, Callable]
markers = None

@property
def num_states(self):
return len(self.init_states)


def setup_cell_model(
cls,
model,
coupling: BaseEMCoupling,
cell_params=None,
cell_inits=None,
cell_init_file=None,
drug_factors_file=None,
popu_factors_file=None,
disease_state=Config.disease_state,
):
cell_params = handle_cell_params(
cell_params=cell_params,
disease_state=disease_state,
drug_factors_file=drug_factors_file,
popu_factors_file=popu_factors_file,
CellModel=cls,
)
) -> CellModel:
# cell_params = handle_cell_params(
# cell_params=cell_params,
# disease_state=disease_state,
# drug_factors_file=drug_factors_file,
# popu_factors_file=popu_factors_file,
# CellModel=cls,
# )
if cell_params is None:
parameters = model.init_parameter_values()
else:
parameters = cell_params

# cell_inits = handle_cell_inits(
# cell_inits=cell_inits,
# cell_init_file=cell_init_file,
# model=model
# )
if cell_inits is None:
init_states = model.init_state_values()
else:
init_states = cell_inits

cell_inits = handle_cell_inits(
cell_inits=cell_inits,
cell_init_file=cell_init_file,
CellModel=cls,
)
# return cls(
# init_conditions=cell_inits,
# params=cell_params,
# coupling=coupling,
# )

# model = beat.cellmodels.torord_dyn_chloride.mid

return cls(
init_conditions=cell_inits,
params=cell_params,
coupling=coupling,
return CellModel(
init_states=init_states,
parameters=parameters,
v_index=model.state_index("v"),
fun=model.forward_generalized_rush_larsen,
)


def setup_solver(
dt,
cellmodel: cbcbeat.CardiacCellModel,
cellmodel: CellModel,
coupling: BaseEMCoupling,
scheme=Config.ep_ode_scheme,
theta=Config.ep_theta,
preconditioner=Config.ep_preconditioner,
PCL=Config.PCL,
) -> cbcbeat.SplittingSolver:
) -> beat.MonodomainSplittingSolver:
# Set-up cardiac model
ps = setup_splitting_solver_parameters(
theta=theta,
preconditioner=preconditioner,
dt=dt,
scheme=scheme,
)
ep_heart = setup_model(
cellmodel,
# ps = setup_splitting_solver_parameters(
# theta=theta,
# preconditioner=preconditioner,
# dt=dt,
# scheme=scheme,
# )
pde = setup_model(
coupling.geometry.ep_mesh,
PCL=PCL,
microstructure=coupling.geometry.microstructure_ep,
stimulus_domain=coupling.geometry.stimulus_domain,
)
solver = cbcbeat.SplittingSolver(ep_heart, params=ps)

ode = beat.odesolver.DolfinODESolver(
pde.state,
num_states=cellmodel.num_states,
fun=cellmodel.fun,
init_states=cellmodel.init_states,
parameters=cellmodel.parameters,
v_index=cellmodel.v_index,
)
# solver = cbcbeat.SplittingSolver(ep_heart, params=ps)
solver = beat.MonodomainSplittingSolver(pde=pde, ode=ode)

# Extract the solution fields and set the initial conditions
(vs_, vs, vur) = solver.solution_fields()
vs_.assign(cellmodel.initial_conditions())
# (vs_, vs, vur) = solver.solution_fields()
# vs_.assign(cellmodel.initial_conditions())

coupling.register_ep_model(solver)
coupling.print_ep_info()
Expand Down Expand Up @@ -141,7 +185,6 @@ def harmonic_mean(a, b):


def setup_model(
cellmodel: cbcbeat.CardiacCellModel,
mesh: dolfin.Mesh,
microstructure: pulse.Microstructure,
stimulus_domain: geometry.StimulusDomain,
Expand All @@ -150,7 +193,7 @@ def setup_model(
C_m: float = 0.01,
duration: float = 2.0,
A: float = 50_000.0,
) -> cbcbeat.CardiacModel:
) -> beat.MonodomainModel:
"""Set-up cardiac model based on benchmark parameters
Expand Down Expand Up @@ -192,7 +235,7 @@ def setup_model(

s = "((std::fmod(time,PCL) >= start) & (std::fmod(time,PCL) <= duration + start)) ? amplitude : 0.0"

I_s = dolfin.Expression(
I_s_expr = dolfin.Expression(
s,
time=time,
start=0.0,
Expand All @@ -201,65 +244,14 @@ def setup_model(
PCL=PCL,
degree=0,
)
# Store input parameters in cardiac model
stimulus = cbcbeat.Markerwise(
(I_s,),
(stimulus_domain.marker,),
stimulus_domain.domain,
)

petsc_options = [
["ksp_type", "cg"],
["pc_type", "gamg"],
["pc_gamg_verbose", "10"],
["pc_gamg_square_graph", "0"],
["pc_gamg_coarse_eq_limit", "3000"],
["mg_coarse_pc_type", "redundant"],
["mg_coarse_sub_pc_type", "lu"],
["mg_levels_ksp_type", "richardson"],
["mg_levels_ksp_max_it", "3"],
["mg_levels_pc_type", "sor"],
]
for opt in petsc_options:
dolfin.PETScOptions.set(*opt)

heart = cbcbeat.CardiacModel(
domain=mesh,
time=time,
M_i=M,
M_e=None,
cell_models=cellmodel,
stimulus=stimulus,
applied_current=None,
dx = dolfin.Measure("dx", domain=mesh, subdomain_data=stimulus_domain.domain)(
stimulus_domain.marker,
)
I_s = beat.base_model.Stimulus(dz=dx, expr=I_s_expr)

return heart


def setup_splitting_solver_parameters(
dt,
theta=0.5,
preconditioner="sor",
scheme="GRL1",
):
ps = cbcbeat.SplittingSolver.default_parameters()
ps["pde_solver"] = "monodomain"
ps["MonodomainSolver"]["linear_solver_type"] = "iterative"
ps["MonodomainSolver"]["theta"] = theta
ps["MonodomainSolver"]["preconditioner"] = preconditioner
ps["MonodomainSolver"]["default_timestep"] = dt
ps["MonodomainSolver"]["use_custom_preconditioner"] = False
ps["theta"] = theta
ps["enable_adjoint"] = False
ps["apply_stimulus_current_to_pde"] = True
# ps["BasicCardiacODESolver"]["scheme"] = scheme
ps["CardiacODESolver"]["scheme"] = scheme
# ps["ode_solver_choice"] = "BasicCardiacODESolver"
# ps["BasicCardiacODESolver"]["V_polynomial_family"] = "CG"
# ps["BasicCardiacODESolver"]["V_polynomial_degree"] = 1
# ps["BasicCardiacODESolver"]["S_polynomial_family"] = "CG"
# ps["BasicCardiacODESolver"]["S_polynomial_degree"] = 1
return ps
params = {"preconditioner": "sor", "use_custom_preconditioner": False}
return beat.MonodomainModel(time=time, mesh=mesh, M=M, I_s=I_s, params=params)


def file_exist(filename: Optional[str], suffix: str) -> bool:
Expand Down Expand Up @@ -311,29 +303,29 @@ def handle_cell_params(
return cell_params_tmp


def handle_cell_inits(
CellModel: Type[cbcbeat.CardiacCellModel],
cell_inits: Optional[Dict[str, float]] = None,
cell_init_file: str = "",
) -> Dict[str, float]:
cell_inits_tmp = CellModel.default_initial_conditions()
if file_exist(cell_init_file, ".json"):
cell_inits_tmp.update(load_json(cell_init_file))

if file_exist(cell_init_file, ".h5"):
cell_inits_tmp.update(load_json(cell_init_file))
from .save_load_functions import load_initial_conditions_from_h5

cell_inits = load_initial_conditions_from_h5(
cell_init_file,
CellModel=CellModel,
)

# FIXME: This is a bit confusing, since it will overwrite the
# inputs from the cell_init_file. There should be only one way to
# do this IMO. I think this might be difficult for the user to reason
# about. I think in general we should handle the loading from files
# at higher level.
if cell_inits is not None:
cell_inits_tmp.update(cell_inits)
return cell_inits_tmp
# def handle_cell_inits(
# CellModel,
# cell_inits: Optional[Dict[str, float]] = None,
# cell_init_file: str = "",
# ) -> Dict[str, float]:
# cell_inits_tmp = CellModel.
# if file_exist(cell_init_file, ".json"):
# cell_inits_tmp.update(load_json(cell_init_file))

# if file_exist(cell_init_file, ".h5"):
# cell_inits_tmp.update(load_json(cell_init_file))
# from .save_load_functions import load_initial_conditions_from_h5

# cell_inits = load_initial_conditions_from_h5(
# cell_init_file,
# CellModel=CellModel,
# )

# # FIXME: This is a bit confusing, since it will overwrite the
# # inputs from the cell_init_file. There should be only one way to
# # do this IMO. I think this might be difficult for the user to reason
# # about. I think in general we should handle the loading from files
# # at higher level.
# if cell_inits is not None:
# cell_inits_tmp.update(cell_inits)
# return cell_inits_tmp
7 changes: 4 additions & 3 deletions src/simcardems/models/em_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from pathlib import Path

import cbcbeat
import beat
import dolfin
import pulse

Expand Down Expand Up @@ -40,7 +40,7 @@ def setup_EM_model(
coupling = cls_EMCoupling(geometry, **state_params)

cellmodel = ep_model.setup_cell_model(
cls=cls_CellModel,
model=cls_CellModel,
coupling=coupling,
cell_init_file=config.cell_init_file,
drug_factors_file=config.drug_factors_file,
Expand All @@ -57,6 +57,7 @@ def setup_EM_model(
PCL=config.PCL,
cellmodel=cellmodel,
)

coupling.register_ep_model(solver)

mech_heart = mechanics_model.setup_solver(
Expand Down Expand Up @@ -150,7 +151,7 @@ def dt_mechanics(self) -> float:
def coupling_type(self):
"CustomType"

def register_ep_model(self, solver: cbcbeat.SplittingSolver) -> None:
def register_ep_model(self, solver: beat.MonodomainSplittingSolver) -> None:
pass

def register_mech_model(self, solver: pulse.MechanicsProblem) -> None:
Expand Down
4 changes: 3 additions & 1 deletion src/simcardems/models/pureEP_ORdmm_Land/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from . import cell_model
from . import cell_model as CellModel
from . import em_model
from .cell_model import ORdmmLandPureEp as CellModel
from .em_model import EMCoupling

# from .cell_model import ORdmmLandPureEp as CellModel

ActiveModel = None
loggers = [
"simcardems.models.pure_ep_ORdmm_Land.cell_model.logger",
Expand Down
Loading

0 comments on commit eec07f7

Please sign in to comment.