Skip to content

Commit

Permalink
Refactor inline expressions expansion into a transformer (#1093)
Browse files Browse the repository at this point in the history
* refactor inline expressions expansion into a transformer

* refactor inline expressions expansion into a transformer

* refactor inline expressions expansion into a transformer

* refactor inline expressions expansion into a transformer

* refactor inline expressions expansion into a transformer

---------

Co-authored-by: C.A.P. Linssen <[email protected]>
  • Loading branch information
clinssen and C.A.P. Linssen authored Aug 19, 2024
1 parent 346fcbe commit d1473bf
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 87 deletions.
8 changes: 3 additions & 5 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from pynestml.symbols.real_type_symbol import RealTypeSymbol
from pynestml.symbols.unit_type_symbol import UnitTypeSymbol
from pynestml.symbols.symbol import SymbolKind
from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
from pynestml.utils.ast_utils import ASTUtils
from pynestml.utils.logger import Logger
from pynestml.utils.logger import LoggingLevel
Expand Down Expand Up @@ -322,8 +323,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
equations_block = neuron.get_equations_blocks()[0]

kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block)
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
ASTUtils.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions())
InlineExpressionExpansionTransformer().transform(neuron)
delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)

Expand Down Expand Up @@ -400,9 +400,7 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
equations_block = synapse.get_equations_blocks()[0]

kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block)
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
ASTUtils.replace_inline_expressions_through_defining_expressions(
equations_block.get_ode_equations(), equations_block.get_inline_expressions())
InlineExpressionExpansionTransformer().transform(synapse)
delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block)
ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block)

Expand Down
11 changes: 4 additions & 7 deletions pynestml/codegeneration/nest_compartmental_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbol_table.symbol_table import SymbolTable
from pynestml.symbols.symbol import SymbolKind
from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
from pynestml.utils.mechanism_processing import MechanismProcessing
from pynestml.utils.channel_processing import ChannelProcessing
from pynestml.utils.concentration_processing import ConcentrationProcessing
Expand Down Expand Up @@ -436,13 +437,9 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)

# substitute inline expressions with each other
# such that no inline expression references another inline expression
ASTUtils.make_inline_expressions_self_contained(
equations_block.get_inline_expressions())

# dereference inline_expressions inside ode equations
ASTUtils.replace_inline_expressions_through_defining_expressions(
equations_block.get_ode_equations(), equations_block.get_inline_expressions())
# such that no inline expression references another inline expression;
# deference inline_expressions inside ode_equations
InlineExpressionExpansionTransformer().transform(neuron)

# generate update expressions using ode toolbox
# for each equation in the equation block attempt to solve analytically
Expand Down
3 changes: 1 addition & 2 deletions pynestml/frontend/pynestml_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
return CodeGenerator("", options)

# cannot reach here due to earlier assert -- silence
# cannot reach here due to earlier assert -- silence static checker warnings
assert "Unknown code generator requested: " + target_name
# static checker warnings


def builder_from_target_name(target_name: str, options: Optional[Mapping[str, Any]] = None) -> Tuple[Builder, Dict[str, Any]]:
Expand Down
142 changes: 142 additions & 0 deletions pynestml/transformers/inline_expression_expansion_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
#
# inline_expression_expansion_transformer.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import List, Optional, Mapping, Any, Union, Sequence

import re

from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_ode_equation import ASTOdeEquation
from pynestml.transformers.transformer import Transformer
from pynestml.utils.ast_utils import ASTUtils
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.string_utils import removesuffix
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor


class InlineExpressionExpansionTransformer(Transformer):
r"""
Make inline expressions self contained, i.e. without any references to other inline expressions.
Additionally, replace variable symbols referencing inline expressions in defining expressions of ODEs with the corresponding defining expressions from the inline expressions.
"""

_variable_matching_template = r'(\b)({})(\b)'

def __init__(self, options: Optional[Mapping[str, Any]] = None):
super(Transformer, self).__init__(options)

def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
single = False
if isinstance(models, ASTNode):
single = True
models = [models]

for model in models:
if not model.get_equations_blocks():
continue

for equations_block in model.get_equations_blocks():
self.make_inline_expressions_self_contained(equations_block.get_inline_expressions())

for equations_block in model.get_equations_blocks():
self.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions())

if single:
return models[0]

return models

def make_inline_expressions_self_contained(self, inline_expressions: List[ASTInlineExpression]) -> List[ASTInlineExpression]:
r"""
Make inline expressions self contained, i.e. without any references to other inline expressions.
:param inline_expressions: A sorted list with entries ASTInlineExpression.
:return: A list with ASTInlineExpressions. Defining expressions don't depend on each other.
"""
from pynestml.utils.model_parser import ModelParser
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

for source in inline_expressions:
source_position = source.get_source_position()
for target in inline_expressions:
matcher = re.compile(self._variable_matching_template.format(source.get_variable_name()))
target_definition = str(target.get_expression())
target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition)
old_parent = target.expression.parent_
target.expression = ModelParser.parse_expression(target_definition)
target.expression.update_scope(source.get_scope())
target.expression.parent_ = old_parent
target.expression.accept(ASTParentVisitor())
target.expression.accept(ASTSymbolTableVisitor())

def log_set_source_position(node):
if node.get_source_position().is_added_source_position():
node.set_source_position(source_position)

target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

return inline_expressions

@classmethod
def replace_inline_expressions_through_defining_expressions(self, definitions: Sequence[ASTOdeEquation],
inline_expressions: Sequence[ASTInlineExpression]) -> Sequence[ASTOdeEquation]:
r"""
Replace variable symbols referencing inline expressions in defining expressions of ODEs with the corresponding defining expressions from the inline expressions.
:param definitions: A list of ODE definitions (**updated in-place**).
:param inline_expressions: A list of inline expression definitions.
:return: A list of updated ODE definitions (same as the ``definitions`` parameter).
"""
from pynestml.utils.model_parser import ModelParser
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

for m in inline_expressions:
if "mechanism" not in [e.namespace for e in m.get_decorators()]:
"""
exclude compartmental mechanism definitions in order to have the
inline as a barrier inbetween odes that are meant to be solved independently
"""
source_position = m.get_source_position()
for target in definitions:
matcher = re.compile(self._variable_matching_template.format(m.get_variable_name()))
target_definition = str(target.get_rhs())
target_definition = re.sub(matcher, "(" + str(m.get_expression()) + ")", target_definition)
old_parent = target.rhs.parent_
target.rhs = ModelParser.parse_expression(target_definition)
target.update_scope(m.get_scope())
target.rhs.parent_ = old_parent
target.rhs.accept(ASTParentVisitor())
target.accept(ASTSymbolTableVisitor())

def log_set_source_position(node):
if node.get_source_position().is_added_source_position():
node.set_source_position(source_position)

target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

return definitions
71 changes: 1 addition & 70 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from pynestml.utils.messages import Messages
from pynestml.utils.string_utils import removesuffix
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_visitor import ASTVisitor


Expand Down Expand Up @@ -1027,8 +1028,6 @@ def has_equation_with_delay_variable(cls, equations_with_delay_vars: ASTOdeEquat
return True
return False

_variable_matching_template = r'(\b)({})(\b)'

@classmethod
def add_declarations_to_internals(cls, neuron: ASTModel, declarations: Mapping[str, str]) -> ASTModel:
"""
Expand Down Expand Up @@ -2080,74 +2079,6 @@ def remove_ode_definitions_from_equations_block(cls, model: ASTModel) -> None:
for decl in decl_to_remove:
equations_block.get_declarations().remove(decl)

@classmethod
def make_inline_expressions_self_contained(cls, inline_expressions: List[ASTInlineExpression]) -> List[ASTInlineExpression]:
"""
Make inline_expressions self contained, i.e. without any references to other inline_expressions.
TODO: it should be a method inside of the ASTInlineExpression
TODO: this should be done by means of a visitor
:param inline_expressions: A sorted list with entries ASTInlineExpression.
:return: A list with ASTInlineExpressions. Defining expressions don't depend on each other.
"""
from pynestml.utils.model_parser import ModelParser
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

for source in inline_expressions:
source_position = source.get_source_position()
for target in inline_expressions:
matcher = re.compile(cls._variable_matching_template.format(source.get_variable_name()))
target_definition = str(target.get_expression())
target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition)
target.expression = ModelParser.parse_expression(target_definition)
target.expression.update_scope(source.get_scope())
target.expression.accept(ASTSymbolTableVisitor())

def log_set_source_position(node):
if node.get_source_position().is_added_source_position():
node.set_source_position(source_position)

target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

return inline_expressions

@classmethod
def replace_inline_expressions_through_defining_expressions(cls, definitions: Sequence[ASTOdeEquation],
inline_expressions: Sequence[ASTInlineExpression]) -> Sequence[ASTOdeEquation]:
"""
Replaces symbols from `inline_expressions` in `definitions` with corresponding defining expressions from `inline_expressions`.
:param definitions: A list of ODE definitions (**updated in-place**).
:param inline_expressions: A list of inline expression definitions.
:return: A list of updated ODE definitions (same as the ``definitions`` parameter).
"""
from pynestml.utils.model_parser import ModelParser
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

for m in inline_expressions:
if "mechanism" not in [e.namespace for e in m.get_decorators()]:
"""
exclude compartmental mechanism definitions in order to have the
inline as a barrier inbetween odes that are meant to be solved independently
"""
source_position = m.get_source_position()
for target in definitions:
matcher = re.compile(cls._variable_matching_template.format(m.get_variable_name()))
target_definition = str(target.get_rhs())
target_definition = re.sub(matcher, "(" + str(m.get_expression()) + ")", target_definition)
target.rhs = ModelParser.parse_expression(target_definition)
target.update_scope(m.get_scope())
target.accept(ASTSymbolTableVisitor())

def log_set_source_position(node):
if node.get_source_position().is_added_source_position():
node.set_source_position(source_position)

target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

return definitions

@classmethod
def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock) -> dict:
r"""
Expand Down
6 changes: 3 additions & 3 deletions pynestml/visitors/ast_symbol_table_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pynestml.meta_model.ast_model_body import ASTModelBody
from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.meta_model.ast_stmt import ASTStmt
from pynestml.meta_model.ast_variable import ASTVariable
Expand Down Expand Up @@ -473,11 +474,10 @@ def visit_variable(self, node: ASTVariable):
node.get_vector_parameter().update_scope(node.get_scope())
node.get_vector_parameter().accept(self)

def visit_inline_expression(self, node):
def visit_inline_expression(self, node: ASTInlineExpression):
"""
Private method: Used to visit a single ode-function, create the corresponding symbol and update the scope.
Private method: Used to visit a single inline expression, create the corresponding symbol and update the scope.
:param node: a single inline expression.
:type node: ASTInlineExpression
"""

# split the decorators in the AST up into namespace decorators and other decorators
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
beta_function_with_inline_expression_neuron
###########################################

Description
+++++++++++

Used for testing processing of inline expressions.


Copyright
+++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
model beta_function_with_inline_expression_neuron:

parameters:
tau1 ms = 20 ms ## decay time
tau2 ms = 10 ms ## rise time

state:
x_ pA/ms = 0 pA/ms
x pA = 0 pA

internals:
alpha real = 42.

equations:
x' = x_ - x / tau2
x_' = - x_ / tau1

recordable inline z pA = x

input:
weighted_input_spikes <- spike

output:
spike

update:
integrate_odes()

onReceive(weighted_input_spikes):
x_ += alpha * (1 / tau2 - 1 / tau1) * pA * weighted_input_spikes * s
Loading

0 comments on commit d1473bf

Please sign in to comment.