Commit

transform kernels and convolutions using a transformer before code generation
C.A.P. Linssen committed Aug 27, 2024
1 parent 410e59a commit 4bd5b42
Showing 13 changed files with 768 additions and 914 deletions.
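
For context: a transformer in PyNestML rewrites the parsed models before any code generation template runs, which is what lets the is_delta_kernel checks and the per-convolution update passes below be dropped from the templates. The sketch below only illustrates that pattern; every name in it is hypothetical and it is not the ConvolutionsTransformer added by this commit (pynestml/transformers/convolutions_transformer.py).

```python
# Illustrative sketch of the "rewrite convolutions before code generation" idea.
# All names here are hypothetical stand-ins, not the real ConvolutionsTransformer.

from dataclasses import dataclass, field
from typing import List


@dataclass
class ModelSketch:
    """Stand-in for a parsed NESTML model."""
    name: str
    kernels: List[str]                      # e.g. ["K"] for `kernel K = exp(-t / tau)`
    odes: List[str] = field(default_factory=list)
    on_spike: List[str] = field(default_factory=list)


class ConvolutionsTransformerSketch:
    """Replace convolve(K, spikes) by an ordinary state variable with its own ODE
    and a spike-triggered increment, so code generators need no special casing."""

    def transform(self, models: List[ModelSketch]) -> List[ModelSketch]:
        for model in models:
            for kernel in model.kernels:
                var = f"{kernel}__X__spikes"          # naming convention seen in the templates
                model.odes.append(f"{var}' = -{var} / tau")
                model.on_spike.append(f"{var} += 1")  # increment on each incoming spike
            model.kernels = []                        # nothing left for templates to special-case
        return models


if __name__ == "__main__":
    m = ModelSketch(name="iaf_psc_exp_nestml", kernels=["K"])
    ConvolutionsTransformerSketch().transform([m])
    print(m.odes, m.on_spike)
```
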
1 change: 0 additions & 1 deletion pynestml/codegeneration/code_generator.py
@@ -120,7 +120,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
# Environment for neuron templates
env = Environment(loader=FileSystemLoader(_template_dirs))
env.globals["raise"] = self.raise_helper
env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel

# Load all the templates
_templates = list()
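
The removed line above relied on Jinja2's env.globals mechanism to expose is_delta_kernel to every template; with the transformer handling kernels up front, the templates no longer need to call it. A small, self-contained illustration of that Jinja2 mechanism (example code only, not NESTML's):

```python
# Minimal illustration of the Jinja2 mechanism the removed line relied on:
# a callable registered in env.globals becomes available inside every template.

from jinja2 import Environment

env = Environment()
env.globals["is_delta_kernel"] = lambda kernel: kernel == "delta"

template = env.from_string("{% if not is_delta_kernel(k) %}copy {{ k }}{% endif %}")
print(template.render(k="exp_kernel"))   # -> "copy exp_kernel"
print(template.render(k="delta"))        # -> ""
```
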
161 changes: 13 additions & 148 deletions pynestml/codegeneration/nest_code_generator.py

Large diffs are not rendered by default.

238 changes: 19 additions & 219 deletions pynestml/codegeneration/nest_compartmental_code_generator.py

Large diffs are not rendered by default.

@@ -261,10 +261,8 @@ std::vector< std::tuple< int, int > > {{ neuronName }}::rport_to_nestml_buffer_i

// copy state struct S_
{%- for init in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
{%- endif %}
{%- endfor %}

// copy internals V_
@@ -723,28 +721,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
update_delay_variables();
{%- endif %}

/**
* subthreshold updates of the convolution variables
*
* step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions
**/

{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{%- if use_gap_junctions %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }};
{%- else %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }};
{%- endif %}
{%- endif %}
{%- endfor %}
{%- endif %}


/**
* Begin NESTML generated code for the update block(s)
**/
@@ -774,30 +750,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
}
{%- endfor %}

/**
* subthreshold updates of the convolution variables
*
* step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep.
**/
{% if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{{ printer.print(var_ast) }} = {{variable_name}}__tmp_;
{%- endif %}
{%- endfor %}
{%- endif %}


/**
* spike updates due to convolutions
**/
{% filter indent(4) %}
{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
{%- endfilter %}

/**
* Begin NESTML generated code for the onCondition block(s)
**/
@@ -348,10 +348,8 @@ public:

{% filter indent(2, True) -%}
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
{% endif %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
{% endfor %}
{%- endfilter %}
{%- endif %}
@@ -970,14 +968,12 @@ inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
{%- endfilter %}
{%- endfor %}

// initial values for state variables in ODE or kernel
// initial values for state variables in ODEs
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
{%- endfilter %}
{%- endif -%}
{%- endfor %}

{{neuron_parent_class}}::get_status( __d );
@@ -1021,14 +1017,12 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
{%- endfilter %}
{%- endfor %}

// initial values for state variables in ODE or kernel
// initial values for state variables in ODEs
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
{%- endfilter %}
{%- endif %}
{%- filter indent(2) %}
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
{%- endfilter %}
{%- endfor %}

// We now know that (ptmp, stmp) are consistent. We do not
@@ -1047,11 +1041,9 @@

{%- for variable_symbol in neuron.get_state_symbols() -%}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
{%- endfilter %}
{%- endif %}
{%- filter indent(2) %}
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
{%- endfilter %}
{%- endfor %}

{% for invariant in neuron.get_parameter_invariants() %}
@@ -822,17 +822,6 @@ public:
{%- endfilter %}
}

/**
* update all convolutions with pre spikes
**/

{%- for spike_updates_for_port in spike_updates.values() %}
{%- for spike_update in spike_updates_for_port %}
{{ printer.print(spike_update.get_variable()) }} += 1.; // XXX: TODO: increment with initial value instead of 1
{%- endfor %}
{%- endfor %}


/**
* in case pre and post spike time coincide and pre update takes priority
**/
@@ -1025,7 +1014,7 @@ void
{%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) and not variable_symbol.is_inline_expression %}
{%- if not isHomogeneous and not variable_symbol.is_inline_expression %}
{%- if variable.get_name() == nest_codegen_opt_delay_variable %}
// special treatment of NEST delay
double tmp_{{ nest_codegen_opt_delay_variable }} = get_delay();
@@ -1061,7 +1050,7 @@ if (__d->known(nest::names::weight))
{%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %}
{%- if not isHomogeneous %}
{%- if variable.get_name() == nest_codegen_opt_delay_variable %}
// special treatment of NEST delay
set_delay(tmp_{{ nest_codegen_opt_delay_variable }});

This file was deleted.

@@ -191,6 +191,7 @@ class Neuron_{{neuronName}}(Neuron):
{%- endif %}
{%- endfor %}
{%- endfilter %}
pass
else:
# internals V_
{%- filter indent(6) %}
@@ -220,10 +221,8 @@ class Neuron_{{neuronName}}(Neuron):
# -------------------------------------------------------------------------
{% filter indent(2, True) -%}
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.get_symbol_name())) %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_py/MemberVariableGetterSetter.jinja2" %}
{%- endif %}
{%- endfor %}
{%- endfilter %}

@@ -264,13 +263,6 @@ class Neuron_{{neuronName}}(Neuron):
{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %}
{%- endif %}

{#- always integrate convolutions in time #}
{%- for var in analytic_state_variables %}
{%- if "__X__" in var %}
{%- set tmp = analytic_state_variables_.append(var) %}
{%- endif %}
{%- endfor %}

{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}

{%- if uses_numeric_solver %}
@@ -285,14 +277,6 @@ class Neuron_{{neuronName}}(Neuron):
def step(self, origin: float, timestep: float) -> None:
__resolution: float = timestep # do not remove, this is necessary for the resolution() function

# -------------------------------------------------------------------------
# integrate variables related to convolutions
# -------------------------------------------------------------------------

{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
{%- endwith %}

# -------------------------------------------------------------------------
# NESTML generated code for the update block
# -------------------------------------------------------------------------
@@ -306,21 +290,6 @@ class Neuron_{{neuronName}}(Neuron):
{%- endfilter %}
{%- endif %}

# -------------------------------------------------------------------------
# integrate variables related to convolutions
# -------------------------------------------------------------------------

{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %}
{%- endwith %}

# -------------------------------------------------------------------------
# process spikes from buffers
# -------------------------------------------------------------------------
{%- filter indent(4, True) -%}
{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %}
{%- endfilter %}

# -------------------------------------------------------------------------
# begin NESTML generated code for the onReceive block(s)
# -------------------------------------------------------------------------
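
With convolutions lowered before code generation, the generated Python step() loses the separate "integrate variables related to convolutions" and "process spikes from buffers" passes removed above; the former convolution variables are expected to be integrated along with the rest of the state. A rough, hypothetical sketch of that control flow (placeholder names and equations, not the literal output of the Python code generation template):

```python
# Hypothetical sketch of the generated step() control flow after this change;
# the class, variable names and the exponential update are placeholders only.

import math


class GeneratedNeuronSketch:
    def __init__(self, tau: float = 10.0):
        self.tau = tau
        self.K__X__spikes = 0.0   # former convolution variable, now ordinary state

    def step(self, origin: float, timestep: float) -> None:
        # one analytic propagation covers all state variables, including the
        # former convolution variables; no separate convolution pass remains
        self.K__X__spikes *= math.exp(-timestep / self.tau)

    def on_receive_spike(self, weight: float) -> None:
        # the spike-triggered increment that ApplySpikesFromBuffers used to emit
        # is assumed to come from the transformed model's own generated code
        self.K__X__spikes += weight


if __name__ == "__main__":
    neuron = GeneratedNeuronSketch()
    neuron.on_receive_spike(1.0)
    neuron.step(origin=0.0, timestep=0.1)
    print(neuron.K__X__spikes)
```
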

This file was deleted.

4 changes: 4 additions & 0 deletions pynestml/frontend/pynestml_frontend.py
@@ -37,6 +37,7 @@
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.transformers.convolutions_transformer import ConvolutionsTransformer
from pynestml.transformers.transformer import Transformer
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -59,6 +60,9 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st
if options is None:
options = {}

# for all targets, add the convolutions transformer
transformers.append(ConvolutionsTransformer())

if target_name.upper() in ["NEST", "SPINNAKER"]:
from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer

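
The hunk above registers ConvolutionsTransformer unconditionally, before any target-specific transformers are appended, so every backend receives models whose convolutions have already been rewritten. A minimal sketch of applying such a transformer list in order, assuming the base Transformer class exposes a transform() hook (the helper name is hypothetical, not the actual frontend API):

```python
# Hypothetical helper showing the order in which a transformer list such as the
# one built by transformers_from_target_name() could be applied; this is not
# actual pynestml frontend code.

from typing import Any, List, Sequence


def apply_transformers(transformers: Sequence[Any], models: List[Any]) -> List[Any]:
    # the convolutions transformer was appended first, so it rewrites the models
    # before any target-specific transformer sees them
    for transformer in transformers:
        models = transformer.transform(models)
    return models
```
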