From 4bd5b42ce1bb6b1e9bae8f01fbbc0a95720eb748 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Tue, 27 Aug 2024 07:27:34 -0700 Subject: [PATCH] transform kernels and convolutions using a transformer before code generation --- pynestml/codegeneration/code_generator.py | 1 - .../codegeneration/nest_code_generator.py | 161 +---- .../nest_compartmental_code_generator.py | 238 +------ .../point_neuron/common/NeuronClass.jinja2 | 48 -- .../point_neuron/common/NeuronHeader.jinja2 | 28 +- .../common/SynapseHeader.h.jinja2 | 15 +- .../ApplySpikesFromBuffers.jinja2 | 6 - .../point_neuron/@NEURON_NAME@.py.jinja2 | 33 +- .../ApplySpikesFromBuffers.jinja2 | 6 - pynestml/frontend/pynestml_frontend.py | 4 + .../transformers/convolutions_transformer.py | 639 ++++++++++++++++++ .../synapse_post_neuron_transformer.py | 96 +-- pynestml/utils/ast_utils.py | 407 ++--------- 13 files changed, 768 insertions(+), 914 deletions(-) delete mode 100644 pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 delete mode 100644 pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 create mode 100644 pynestml/transformers/convolutions_transformer.py diff --git a/pynestml/codegeneration/code_generator.py b/pynestml/codegeneration/code_generator.py index 092ce3414..f95fafb1d 100644 --- a/pynestml/codegeneration/code_generator.py +++ b/pynestml/codegeneration/code_generator.py @@ -120,7 +120,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str # Environment for neuron templates env = Environment(loader=FileSystemLoader(_template_dirs)) env.globals["raise"] = self.raise_helper - env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel # Load all the templates _templates = list() diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 0471d0041..fade4d495 100644 --- 
a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -53,7 +53,6 @@ from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_input_port import ASTInputPort -from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node_factory import ASTNodeFactory from pynestml.meta_model.ast_ode_equation import ASTOdeEquation @@ -168,6 +167,10 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None): self.setup_printers() def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: Sequence[ASTModel]): + for model in neurons + synapses: + for equations_block in model.get_equations_blocks(): + assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer" + for synapse in synapses: synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), FrontendConfiguration.suffix) delay_variable = self.get_option("delay_variable")[synapse_name_stripped] @@ -281,9 +284,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: for neuron in neurons: code, message = Messages.get_analysing_transforming_model(neuron.get_name()) Logger.log_message(None, code, message, None, LoggingLevel.INFO) - spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron) - neuron.spike_updates = spike_updates - neuron.post_spike_updates = post_spike_updates + equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron) neuron.equations_with_delay_vars = equations_with_delay_vars neuron.equations_with_vector_vars = equations_with_vector_vars @@ -294,14 +295,12 @@ def analyse_transform_synapses(self, synapses: List[ASTModel]) -> None: """ for synapse in synapses: 
Logger.log_message(None, None, "Analysing/transforming synapse {}.".format(synapse.get_name()), None, LoggingLevel.INFO) - synapse.spike_updates = self.analyse_synapse(synapse) + self.analyse_synapse(synapse) def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment], List[ASTOdeEquation], List[ASTOdeEquation]]: """ Analyse and transform a single neuron. :param neuron: a single neuron. - :return: see documentation for get_spike_update_expressions() for more information. - :return: post_spike_updates: list of post-synaptic spike update expressions :return: equations_with_delay_vars: list of equations containing delay variables :return: equations_with_vector_vars: list of equations containing delay variables """ @@ -315,17 +314,14 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di ASTUtils.all_variables_defined_in_block(neuron.get_state_blocks())) ASTUtils.add_timestep_symbol(neuron) - return {}, {}, [], [] + return [], [] if len(neuron.get_equations_blocks()) > 1: raise Exception("Only one equations block per model supported for now") equations_block = neuron.get_equations_blocks()[0] - kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block) InlineExpressionExpansionTransformer().transform(neuron) - delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) - ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block) # Collect all equations with delay variables and replace ASTFunctionCall to ASTVariable wherever necessary equations_with_delay_vars_visitor = ASTEquationsWithDelayVarsVisitor() @@ -337,7 +333,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di neuron.accept(eqns_with_vector_vars_visitor) equations_with_vector_vars = eqns_with_vector_vars_visitor.equations - analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron, kernel_buffers) + analytic_solver, numeric_solver = 
self.ode_toolbox_analysis(neuron) self.analytic_solver[neuron.get_name()] = analytic_solver self.numeric_solver[neuron.get_name()] = numeric_solver @@ -351,23 +347,14 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di if ode_eq.get_lhs().get_name() == var.get_name(): used_in_eq = True break - for kern in equations_block.get_kernels(): - for kern_var in kern.get_variables(): - if kern_var.get_name() == var.get_name(): - used_in_eq = True - break if not used_in_eq: self.non_equations_state_variables[neuron.get_name()].append(var) - ASTUtils.remove_initial_values_for_kernels(neuron) - kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron) ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver]) ASTUtils.remove_ode_definitions_from_equations_block(neuron) - ASTUtils.create_initial_values_for_kernels(neuron, [analytic_solver, numeric_solver], kernels) ASTUtils.create_integrate_odes_combinations(neuron) ASTUtils.replace_variable_names_in_expressions(neuron, [analytic_solver, numeric_solver]) - ASTUtils.replace_convolution_aliasing_inlines(neuron) ASTUtils.add_timestep_symbol(neuron) if self.analytic_solver[neuron.get_name()] is not None: @@ -380,9 +367,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di # Update the delay parameter parameters after symbol table update ASTUtils.update_delay_parameter_in_state_vars(neuron, state_vars_before_update) - spike_updates, post_spike_updates = self.get_spike_update_expressions(neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) - - return spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars + return equations_with_delay_vars, equations_with_vector_vars def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: """ @@ -392,32 +377,24 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: code, message = 
Messages.get_start_processing_model(synapse.get_name()) Logger.log_message(synapse, code, message, synapse.get_source_position(), LoggingLevel.INFO) - spike_updates = {} if synapse.get_equations_blocks(): if len(synapse.get_equations_blocks()) > 1: raise Exception("Only one equations block per model supported for now") equations_block = synapse.get_equations_blocks()[0] - kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block) InlineExpressionExpansionTransformer().transform(synapse) - delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block) - ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse, kernel_buffers) self.analytic_solver[synapse.get_name()] = analytic_solver self.numeric_solver[synapse.get_name()] = numeric_solver - ASTUtils.remove_initial_values_for_kernels(synapse) - kernels = ASTUtils.remove_kernel_definitions_from_equations_block(synapse) ASTUtils.update_initial_values_for_odes(synapse, [analytic_solver, numeric_solver]) ASTUtils.remove_ode_definitions_from_equations_block(synapse) - ASTUtils.create_initial_values_for_kernels(synapse, [analytic_solver, numeric_solver], kernels) ASTUtils.create_integrate_odes_combinations(synapse) ASTUtils.replace_variable_names_in_expressions(synapse, [analytic_solver, numeric_solver]) ASTUtils.add_timestep_symbol(synapse) self.update_symbol_table(synapse) - spike_updates, _ = self.get_spike_update_expressions(synapse, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) if not self.analytic_solver[synapse.get_name()] is None: synapse = ASTUtils.add_declarations_to_internals( @@ -435,8 +412,6 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options" assert 
ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "'" - return spike_updates - def _get_model_namespace(self, astnode: ASTModel) -> Dict: namespace = {} @@ -589,8 +564,6 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast - namespace["spike_updates"] = synapse.spike_updates - synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), FrontendConfiguration.suffix) # special case for NEST delay variable (state or parameter) @@ -616,7 +589,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: if "paired_synapse" in dir(neuron): namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse namespace["paired_synapse"] = neuron.paired_synapse.get_name() - namespace["post_spike_updates"] = neuron.post_spike_updates namespace["transferred_variables"] = neuron._transferred_variables namespace["transferred_variables_syms"] = {var_name: neuron.scope.resolve_to_symbol( var_name, SymbolKind.VARIABLE) for var_name in namespace["transferred_variables"]} @@ -769,8 +741,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["numerical_state_symbols"] = numeric_state_variable_names ASTUtils.assign_numeric_non_numeric_state_variables(neuron, numeric_state_variable_names, namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None) - namespace["spike_updates"] = neuron.spike_updates - namespace["recordable_state_variables"] = [] for state_block in neuron.get_state_blocks(): for decl in state_block.get_declarations(): @@ -778,7 +748,6 @@ def 
_get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE) if isinstance(sym.get_type_symbol(), (UnitTypeSymbol, RealTypeSymbol)) \ - and not ASTUtils.is_delta_kernel(neuron.get_kernel_by_name(sym.name)) \ and sym.is_recordable: namespace["recordable_state_variables"].append(var) @@ -788,7 +757,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: for var in decl.get_variables(): sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE) - if sym.has_declaring_expression() and (not neuron.get_kernel_by_name(sym.name)): + if sym.has_declaring_expression(): namespace["parameter_vars_with_iv"].append(var) namespace["recordable_inline_expressions"] = [sym for sym in neuron.get_inline_expression_symbols() @@ -807,7 +776,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: return namespace - def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): + def ode_toolbox_analysis(self, neuron: ASTModel): """ Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output. 
""" @@ -816,11 +785,11 @@ def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKern equations_block = neuron.get_equations_blocks()[0] - if len(equations_block.get_kernels()) == 0 and len(equations_block.get_ode_equations()) == 0: + if len(equations_block.get_ode_equations()) == 0: # no equations defined -> no changes to the neuron return None, None - odetoolbox_indict = ASTUtils.transform_ode_and_kernels_to_json(neuron, neuron.get_parameters_blocks(), kernel_buffers, printer=self._ode_toolbox_printer) + odetoolbox_indict = ASTUtils.transform_odes_to_json(neuron, neuron.get_parameters_blocks(), printer=self._ode_toolbox_printer) odetoolbox_indict["options"] = {} odetoolbox_indict["options"]["output_timestep_symbol"] = "__h" disable_analytic_solver = self.get_option("solver") != "analytic" @@ -864,107 +833,3 @@ def update_symbol_table(self, neuron) -> None: symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) - - def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: - r""" - Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after - ode-toolbox. - - For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model. - from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from - user specification in the model. - - Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the - initial value of the corresponding ODE dimension. - ``spike_updates`` is a mapping from input port name (as a string) to update expressions. 
- - ``post_spike_updates`` is a mapping from kernel name (as a string) to update expressions. - """ - spike_updates = {} - post_spike_updates = {} - - for kernel, spike_input_port in kernel_buffers: - if ASTUtils.is_delta_kernel(kernel): - continue - - spike_input_port_name = spike_input_port.get_variable().get_name() - - if not spike_input_port_name in spike_updates.keys(): - spike_updates[str(spike_input_port)] = [] - - if "_is_post_port" in dir(spike_input_port.get_variable()) \ - and spike_input_port.get_variable()._is_post_port: - # it's a port in the neuron ??? that receives post spikes ??? - orig_port_name = spike_input_port_name[:spike_input_port_name.index("__for_")] - buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol() - else: - buffer_type = neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE).get_type_symbol() - - assert not buffer_type is None - - for kernel_var in kernel.get_variables(): - for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order) - expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts) - assert expr is not None, "Initial value not found for kernel " + kernel_var - expr = str(expr) - if expr in ["0", "0.", "0.0"]: - continue # skip adding the statement if we are only adding zero - - assignment_str = kernel_spike_buf_name + " += " - if "_is_post_port" in dir(spike_input_port.get_variable()) \ - and spike_input_port.get_variable()._is_post_port: - assignment_str += "1." 
- else: - assignment_str += "(" + str(spike_input_port) + ")" - if not expr in ["1.", "1.0", "1"]: - assignment_str += " * (" + expr + ")" - - if not buffer_type.print_nestml_type() in ["1.", "1.0", "1", "real", "integer"]: - assignment_str += " / (" + buffer_type.print_nestml_type() + ")" - - ast_assignment = ModelParser.parse_assignment(assignment_str) - ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - if neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE) is None: - # this case covers variables that were moved from synapse to the neuron - post_spike_updates[kernel_var.get_name()] = ast_assignment - elif "_is_post_port" in dir(spike_input_port.get_variable()) and spike_input_port.get_variable()._is_post_port: - Logger.log_message(None, None, "Adding post assignment string: " + str(ast_assignment), None, LoggingLevel.INFO) - spike_updates[str(spike_input_port)].append(ast_assignment) - else: - spike_updates[str(spike_input_port)].append(ast_assignment) - - for k, factor in delta_factors.items(): - var = k[0] - inport = k[1] - assignment_str = var.get_name() + "'" * (var.get_differential_order() - 1) + " += " - if not factor in ["1.", "1.0", "1"]: - factor_expr = ModelParser.parse_expression(factor) - factor_expr.update_scope(neuron.get_scope()) - factor_expr.accept(ASTSymbolTableVisitor()) - assignment_str += "(" + self._printer_no_origin.print(factor_expr) + ") * " - - if "_is_post_port" in dir(inport) and inport._is_post_port: - orig_port_name = inport[:inport.index("__for_")] - buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol() - else: - buffer_type = neuron.get_scope().resolve_to_symbol(inport.get_name(), SymbolKind.VARIABLE).get_type_symbol() - - assignment_str += str(inport) - if not buffer_type.print_nestml_type() in ["1.", "1.0", "1"]: - assignment_str += " / (" + buffer_type.print_nestml_type() + ")" - 
ast_assignment = ModelParser.parse_assignment(assignment_str) - ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - inport_name = inport.get_name() - if inport.has_vector_parameter(): - inport_name += "_" + str(ASTUtils.get_numeric_vector_size(inport)) - if not inport_name in spike_updates.keys(): - spike_updates[inport_name] = [] - - spike_updates[inport_name].append(ast_assignment) - - return spike_updates, post_spike_updates diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 2e0fc37b6..55c22e9de 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -47,7 +47,6 @@ from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables from pynestml.meta_model.ast_input_port import ASTInputPort -from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node_factory import ASTNodeFactory from pynestml.meta_model.ast_variable import ASTVariable @@ -128,10 +127,6 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None): self.setup_printers() - # maps kernel names to their analytic solutions separately - # this is needed for the cm_syns case - self.kernel_name_to_analytic_solver = {} - def setup_printers(self): self._constant_printer = ConstantPrinter() @@ -276,27 +271,20 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: code, message = Messages.get_analysing_transforming_model( neuron.get_name()) Logger.log_message(None, code, message, None, LoggingLevel.INFO) - spike_updates = self.analyse_neuron(neuron) - neuron.spike_updates = spike_updates + self.analyse_neuron(neuron) def create_ode_indict(self, neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - 
kernel_buffers: Mapping[ASTKernel, - ASTInputPort]): - odetoolbox_indict = self.transform_ode_and_kernels_to_json( - neuron, parameters_block, kernel_buffers) + parameters_block: ASTBlockWithVariables): + odetoolbox_indict = self.transform_odes_to_json(neuron, parameters_block) odetoolbox_indict["options"] = {} odetoolbox_indict["options"]["output_timestep_symbol"] = "__h" return odetoolbox_indict def ode_solve_analytically(self, neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - kernel_buffers: Mapping[ASTKernel, - ASTInputPort]): - odetoolbox_indict = self.create_ode_indict( - neuron, parameters_block, kernel_buffers) + parameters_block: ASTBlockWithVariables): + odetoolbox_indict = self.create_ode_indict(neuron, parameters_block) full_solver_result = analysis( odetoolbox_indict, @@ -315,8 +303,7 @@ def ode_solve_analytically(self, return full_solver_result, analytic_solver - def ode_toolbox_analysis(self, neuron: ASTModel, - kernel_buffers: Mapping[ASTKernel, ASTInputPort]): + def ode_toolbox_analysis(self, neuron: ASTModel): """ Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output. 
""" @@ -325,15 +312,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel, equations_block = neuron.get_equations_blocks()[0] - if len(equations_block.get_kernels()) == 0 and len( - equations_block.get_ode_equations()) == 0: + if len(equations_block.get_ode_equations()) == 0: # no equations defined -> no changes to the neuron return None, None parameters_block = neuron.get_parameters_blocks()[0] - solver_result, analytic_solver = self.ode_solve_analytically( - neuron, parameters_block, kernel_buffers) + solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block) # if numeric solver is required, generate a stepping function that # includes each state variable @@ -342,8 +327,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel, x for x in solver_result if x["solver"].startswith("numeric")] if numeric_solvers: - odetoolbox_indict = self.create_ode_indict( - neuron, parameters_block, kernel_buffers) + odetoolbox_indict = self.create_ode_indict(neuron, parameters_block) solver_result = analysis( odetoolbox_indict, disable_stiffness_check=True, @@ -381,13 +365,6 @@ def find_non_equations_state_variables(self, neuron: ASTModel): used_in_eq = True break - # check for any state variables being used by a kernel - for kern in neuron.get_equations_blocks()[0].get_kernels(): - for kern_var in kern.get_variables(): - if kern_var.get_name() == var.get_name(): - used_in_eq = True - break - # if no usage found at this point, we have a non-equation state # variable if not used_in_eq: @@ -416,26 +393,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: self.non_equations_state_variables[neuron.get_name()].extend( ASTUtils.all_variables_defined_in_block(neuron.get_state_blocks()[0])) - return [] - - # goes through all convolve() inside ode's from equations block - # if they have delta kernels, use sympy to expand the expression, then - # find the convolve calls and replace them with constant value 1 - # then return every subexpression that 
had that convolve() replaced - delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) - - # goes through all convolve() inside equations block - # extracts what kernel is paired with what spike buffer - # returns pairs (kernel, spike_buffer) - kernel_buffers = ASTUtils.generate_kernel_buffers( - neuron, equations_block) - - # replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d] - # done by searching for every ASTSimpleExpression inside equations_block - # which is a convolve call and substituting that call with - # newly created ASTVariable kernel__X__spike_buffer - ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block) - # substitute inline expressions with each other # such that no inline expression references another inline expression; # deference inline_expressions inside ode_equations @@ -447,16 +404,7 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # "update_expressions" key in those solvers contains a mapping # {expression1: update_expression1, expression2: update_expression2} - analytic_solver, numeric_solver = self.ode_toolbox_analysis( - neuron, kernel_buffers) - - """ - # separate analytic solutions by kernel - # this is is needed for the synaptic case - self.kernel_name_to_analytic_solver[neuron.get_name( - )] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers) - """ - + analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron) self.analytic_solver[neuron.get_name()] = analytic_solver self.numeric_solver[neuron.get_name()] = numeric_solver @@ -464,17 +412,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: self.non_equations_state_variables[neuron.get_name()] = \ self.find_non_equations_state_variables(neuron) - # gather all variables used by kernels and delete their declarations - # they will be inserted later again, but this time with values redefined - # by odetoolbox, higher order variables don't get deleted here - 
ASTUtils.remove_initial_values_for_kernels(neuron) - - # delete all kernels as they are all converted into buffers - # and corresponding update formulas calculated by odetoolbox - # Remember them in a variable though - kernels = ASTUtils.remove_kernel_definitions_from_equations_block( - neuron) - # Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions # their initial values are replaced by expressions suggested by ODE-toolbox. # Differential order can now be set to 0, becase they can directly represent the value of the derivative now. @@ -488,22 +425,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # corresponding updates ASTUtils.remove_ode_definitions_from_equations_block(neuron) - # restore state variables that were referenced by kernels - # and set their initial values by those suggested by ODE-toolbox - ASTUtils.create_initial_values_for_kernels( - neuron, [analytic_solver, numeric_solver], kernels) - # Inside all remaining expressions, translate all remaining variable names # according to the naming conventions of ODE-toolbox. 
ASTUtils.replace_variable_names_in_expressions( neuron, [analytic_solver, numeric_solver]) - # find all inline kernels defined as ASTSimpleExpression - # that have a single kernel convolution aliasing variable ('__X__') - # translate all remaining variable names according to the naming - # conventions of ODE-toolbox - ASTUtils.replace_convolution_aliasing_inlines(neuron) - # add variable __h to internals block ASTUtils.add_timestep_symbol(neuron) @@ -513,10 +439,10 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: neuron, self.analytic_solver[neuron.get_name()]["propagators"]) # generate how to calculate the next spike update - self.update_symbol_table(neuron, kernel_buffers) + self.update_symbol_table(neuron) # find any spike update expressions defined by the user spike_updates = self.get_spike_update_expressions( - neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) + neuron, [analytic_solver, numeric_solver]) return spike_updates @@ -674,20 +600,16 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast - namespace["spike_updates"] = neuron.spike_updates - namespace["recordable_state_variables"] = [ sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type( - sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel( - neuron.get_kernel_by_name( - sym.name))] + sym.get_type_symbol()) == "double" and sym.is_recordable] namespace["recordable_inline_expressions"] = [ sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type( sym.get_type_symbol()) == "double" and sym.is_recordable] # parameter symbols with initial values namespace["parameter_syms_with_iv"] = [sym for sym in neuron.get_parameter_symbols( - ) if sym.has_declaring_expression() and (not neuron.get_kernel_by_name(sym.name))] + ) if 
sym.has_declaring_expression()] namespace["cm_unique_suffix"] = self.getUniqueSuffix(neuron) # get the mechanisms info dictionaries and enrich them. @@ -721,7 +643,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: return namespace - def update_symbol_table(self, neuron, kernel_buffers): + def update_symbol_table(self, neuron): """ Update symbol table and scope. """ @@ -742,7 +664,7 @@ def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]: return None def create_initial_values_for_ode_toolbox_odes( - self, neuron, solver_dicts, kernel_buffers, kernels): + self, neuron, solver_dicts): """ Add the variables used in ODEs from the ode-toolbox result dictionary as ODEs in NESTML AST. """ @@ -763,101 +685,17 @@ def create_initial_values_for_ode_toolbox_odes( # here, overwrite is allowed because initial values might be # repeated between numeric and analytic solver - if ASTUtils.variable_in_kernels(var_name, kernels): - expr = "0" # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0 - if not ASTUtils.declaration_in_state_block(neuron, var_name): ASTUtils.add_declaration_to_state_block( neuron, var_name, expr) - def get_spike_update_expressions( + def transform_odes_to_json( self, neuron: ASTModel, - kernel_buffers, - solver_dicts, - delta_factors) -> List[ASTAssignment]: - """ - Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after ode-toolbox. - - For example, a resulting `assignment_str` could be "I_kernel_in += (in_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model. - - Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the initial value of the corresponding ODE dimension. 
- - XXX: TODO: update this function signature (+ templates) to match NESTCodegenerator::get_spike_update_expressions(). - - - """ - spike_updates = [] - - for kernel, spike_input_port in kernel_buffers: - if neuron.get_scope().resolve_to_symbol( - str(spike_input_port), SymbolKind.VARIABLE) is None: - continue - - buffer_type = neuron.get_scope().resolve_to_symbol( - str(spike_input_port), SymbolKind.VARIABLE).get_type_symbol() - - if ASTUtils.is_delta_kernel(kernel): - continue - - for kernel_var in kernel.get_variables(): - for var_order in range( - ASTUtils.get_kernel_var_order_from_ode_toolbox_result( - kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, var_order) - expr = ASTUtils.get_initial_value_from_ode_toolbox_result( - kernel_spike_buf_name, solver_dicts) - assert expr is not None, "Initial value not found for kernel " + kernel_var - expr = str(expr) - if expr in ["0", "0.", "0.0"]: - continue # skip adding the statement if we're only adding zero - - assignment_str = kernel_spike_buf_name + " += " - assignment_str += "(" + str(spike_input_port) + ")" - if expr not in ["1.", "1.0", "1"]: - assignment_str += " * (" + expr + ")" - - if not buffer_type.print_nestml_type() in ["1.", "1.0", "1"]: - assignment_str += " / (" + buffer_type.print_nestml_type() + ")" - - ast_assignment = ModelParser.parse_assignment( - assignment_str) - ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - spike_updates.append(ast_assignment) - - for k, factor in delta_factors.items(): - var = k[0] - inport = k[1] - assignment_str = var.get_name() + "'" * (var.get_differential_order() - 1) + " += " - if factor not in ["1.", "1.0", "1"]: - assignment_str += "(" + self._printer.print(ModelParser.parse_expression(factor)) + ") * " - assignment_str += str(inport) - ast_assignment = ModelParser.parse_assignment(assignment_str) - 
ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - spike_updates.append(ast_assignment) - - return spike_updates - - def transform_ode_and_kernels_to_json( - self, - neuron: ASTModel, - parameters_block, - kernel_buffers): + parameters_block): """ Converts AST node to a JSON representation suitable for passing to ode-toolbox. - Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements - - convolve(G, ex_spikes) - convolve(G, in_spikes) - - then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`. - :param parameters_block: ASTBlockWithVariables :return: Dict """ @@ -866,8 +704,7 @@ def transform_ode_and_kernels_to_json( equations_block = neuron.get_equations_blocks()[0] for equation in equations_block.get_ode_equations(): - # n.b. includes single quotation marks to indicate differential - # order + # n.b. 
includes single quotation marks to indicate differential order lhs = ASTUtils.to_ode_toolbox_name( equation.get_lhs().get_complete_name()) rhs = self._ode_toolbox_printer.print(equation.get_rhs()) @@ -887,43 +724,6 @@ def transform_ode_and_kernels_to_json( iv_symbol_name)] = expr odetoolbox_indict["dynamics"].append(entry) - # write a copy for each (kernel, spike buffer) combination - for kernel, spike_input_port in kernel_buffers: - - if ASTUtils.is_delta_kernel(kernel): - # delta function -- skip passing this to ode-toolbox - continue - - for kernel_var in kernel.get_variables(): - expr = ASTUtils.get_expr_from_kernel_var( - kernel, kernel_var.get_complete_name()) - kernel_order = kernel_var.get_differential_order() - kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") - - ASTUtils.replace_rhs_variables(expr, kernel_buffers) - - entry = {} - entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr) - - # initial values need to be declared for order 1 up to kernel - # order (e.g. none for kernel function f(t) = ...; 1 for kernel - # ODE f'(t) = ...; 2 for f''(t) = ... 
and so on) - entry["initial_values"] = {} - for order in range(kernel_order): - iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") - symbol_name_ = kernel_var.get_name() + "'" * order - symbol = equations_block.get_scope().resolve_to_symbol( - symbol_name_, SymbolKind.VARIABLE) - assert symbol is not None, "Could not find initial value for variable " + symbol_name_ - initial_value_expr = symbol.get_declaring_expression() - assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ - entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print( - initial_value_expr) - - odetoolbox_indict["dynamics"].append(entry) - odetoolbox_indict["parameters"] = {} if parameters_block is not None: for decl in parameters_block.get_declarations(): diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 index a59133e33..54e7d8f13 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 @@ -261,10 +261,8 @@ std::vector< std::tuple< int, int > > {{ neuronName }}::rport_to_nestml_buffer_i // copy state struct S_ {%- for init in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %} {%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %} {{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }}; -{%- endif %} {%- endfor %} // copy internals V_ @@ -723,28 +721,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const update_delay_variables(); {%- endif %} - /** - * subthreshold updates of 
the convolution variables - * - * step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions - **/ - -{%- if uses_analytic_solver %} -{%- for variable_name in analytic_state_variables: %} -{%- if "__X__" in variable_name %} -{%- set update_expr = update_expressions[variable_name] %} -{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%} -{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%} -{%- if use_gap_junctions %} - const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}; -{%- else %} - const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }}; -{%- endif %} -{%- endif %} -{%- endfor %} -{%- endif %} - - /** * Begin NESTML generated code for the update block(s) **/ @@ -774,30 +750,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const } {%- endfor %} - /** - * subthreshold updates of the convolution variables - * - * step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep. 
- **/ -{% if uses_analytic_solver %} -{%- for variable_name in analytic_state_variables: %} -{%- if "__X__" in variable_name %} -{%- set update_expr = update_expressions[variable_name] %} -{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%} -{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%} - {{ printer.print(var_ast) }} = {{variable_name}}__tmp_; -{%- endif %} -{%- endfor %} -{%- endif %} - - - /** - * spike updates due to convolutions - **/ -{% filter indent(4) %} -{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %} -{%- endfilter %} - /** * Begin NESTML generated code for the onCondition block(s) **/ diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 index cd509d944..8b2940cb5 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 @@ -348,10 +348,8 @@ public: {% filter indent(2, True) -%} {%- for variable_symbol in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %} -{% endif %} +{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %} {% endfor %} {%- endfilter %} {%- endif %} @@ -970,14 +968,12 @@ inline void {{neuronName}}::get_status(DictionaryDatum &__d) const {%- endfilter %} {%- endfor %} - // initial values for state variables in ODE or kernel + // initial values for state variables in ODEs {%- for variable_symbol in neuron.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, 
variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} {%- filter indent(2) %} {%- include "directives_cpp/WriteInDictionary.jinja2" %} {%- endfilter %} -{%- endif -%} {%- endfor %} {{neuron_parent_class}}::get_status( __d ); @@ -1021,14 +1017,12 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d) {%- endfilter %} {%- endfor %} - // initial values for state variables in ODE or kernel + // initial values for state variables in ODEs {%- for variable_symbol in neuron.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2) %} -{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} -{%- endfilter %} -{%- endif %} +{%- filter indent(2) %} +{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} +{%- endfilter %} {%- endfor %} // We now know that (ptmp, stmp) are consistent. 
We do not @@ -1047,11 +1041,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d) {%- for variable_symbol in neuron.get_state_symbols() -%} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2) %} -{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} -{%- endfilter %} -{%- endif %} +{%- filter indent(2) %} +{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} +{%- endfilter %} {%- endfor %} {% for invariant in neuron.get_parameter_invariants() %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 index 0f0570656..992294720 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 @@ -822,17 +822,6 @@ public: {%- endfilter %} } - /** - * update all convolutions with pre spikes - **/ - -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for spike_update in spike_updates_for_port %} - {{ printer.print(spike_update.get_variable()) }} += 1.; // XXX: TODO: increment with initial value instead of 1 -{%- endfor %} -{%- endfor %} - - /** * in case pre and post spike time coincide and pre update takes priority **/ @@ -1025,7 +1014,7 @@ void {%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %} {%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %} {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) and not variable_symbol.is_inline_expression %} +{%- if not isHomogeneous and not 
variable_symbol.is_inline_expression %} {%- if variable.get_name() == nest_codegen_opt_delay_variable %} // special treatment of NEST delay double tmp_{{ nest_codegen_opt_delay_variable }} = get_delay(); @@ -1061,7 +1050,7 @@ if (__d->known(nest::names::weight)) {%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %} {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %} -{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %} +{%- if not isHomogeneous %} {%- if variable.get_name() == nest_codegen_opt_delay_variable %} // special treatment of NEST delay set_delay(tmp_{{ nest_codegen_opt_delay_variable }}); diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 deleted file mode 100644 index 881257451..000000000 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{% if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for ast in spike_updates_for_port -%} -{%- include "directives_cpp/Assignment.jinja2" %} -{%- endfor %} -{%- endfor %} diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 index 5fc3ae589..87c892608 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 @@ -191,6 +191,7 @@ class Neuron_{{neuronName}}(Neuron): {%- endif %} {%- 
endfor %} {%- endfilter %} + pass else: # internals V_ {%- filter indent(6) %} @@ -220,10 +221,8 @@ class Neuron_{{neuronName}}(Neuron): # ------------------------------------------------------------------------- {% filter indent(2, True) -%} {%- for variable_symbol in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.get_symbol_name())) %} {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- include "directives_py/MemberVariableGetterSetter.jinja2" %} -{%- endif %} {%- endfor %} {%- endfilter %} @@ -264,13 +263,6 @@ class Neuron_{{neuronName}}(Neuron): {%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %} {%- endif %} -{#- always integrate convolutions in time #} -{%- for var in analytic_state_variables %} -{%- if "__X__" in var %} -{%- set tmp = analytic_state_variables_.append(var) %} -{%- endif %} -{%- endfor %} - {%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %} {%- if uses_numeric_solver %} @@ -285,14 +277,6 @@ class Neuron_{{neuronName}}(Neuron): def step(self, origin: float, timestep: float) -> None: __resolution: float = timestep # do not remove, this is necessary for the resolution() function - # ------------------------------------------------------------------------- - # integrate variables related to convolutions - # ------------------------------------------------------------------------- - -{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %} -{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %} -{%- endwith %} - # ------------------------------------------------------------------------- # NESTML generated code for the update block # ------------------------------------------------------------------------- @@ -306,21 +290,6 @@ class Neuron_{{neuronName}}(Neuron): {%- endfilter %} {%- endif %} - # 
------------------------------------------------------------------------- - # integrate variables related to convolutions - # ------------------------------------------------------------------------- - -{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %} -{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %} -{%- endwith %} - - # ------------------------------------------------------------------------- - # process spikes from buffers - # ------------------------------------------------------------------------- -{%- filter indent(4, True) -%} -{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %} -{%- endfilter %} - # ------------------------------------------------------------------------- # begin NESTML generated code for the onReceive block(s) # ------------------------------------------------------------------------- diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 deleted file mode 100644 index c0952b2f5..000000000 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{%- if tracing %}# generated by {{self._TemplateReference__context.name}}{% endif %} -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for ast in spike_updates_for_port -%} -{%- include "directives_py/Assignment.jinja2" %} -{%- endfor %} -{%- endfor %} diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 11e7c2f00..3cfc32ed3 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -37,6 +37,7 @@ from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import 
PredefinedVariables +from pynestml.transformers.convolutions_transformer import ConvolutionsTransformer +from pynestml.transformers.transformer import Transformer +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.utils.messages import Messages @@ -59,6 +60,9 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st if options is None: options = {} + # for all targets, add the convolutions transformer + transformers.append(ConvolutionsTransformer()) + if target_name.upper() in ["NEST", "SPINNAKER"]: from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer diff --git a/pynestml/transformers/convolutions_transformer.py b/pynestml/transformers/convolutions_transformer.py new file mode 100644 index 000000000..f198a5726 --- /dev/null +++ b/pynestml/transformers/convolutions_transformer.py @@ -0,0 +1,639 @@ +# -*- coding: utf-8 -*- +# +# convolutions_transformer.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+ +from __future__ import annotations + +from typing import Any, Dict, List, Sequence, Mapping, Optional, Tuple, Union + +import re + +import odetoolbox + +from pynestml.codegeneration.printers.ast_printer import ASTPrinter +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter +from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter +from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter +from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter +from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.meta_model.ast_assignment import ASTAssignment +from pynestml.meta_model.ast_block import ASTBlock +from pynestml.meta_model.ast_data_type import ASTDataType +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_equations_block import ASTEquationsBlock +from pynestml.meta_model.ast_expression import ASTExpression +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_input_port import ASTInputPort +from pynestml.meta_model.ast_kernel import ASTKernel +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_node_factory import ASTNodeFactory +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression +from pynestml.meta_model.ast_small_stmt import ASTSmallStmt +from pynestml.meta_model.ast_variable import ASTVariable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.real_type_symbol import RealTypeSymbol +from pynestml.symbols.symbol import SymbolKind +from pynestml.symbols.variable_symbol import BlockType +from 
pynestml.transformers.transformer import Transformer +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import Logger +from pynestml.utils.logger import LoggingLevel +from pynestml.utils.model_parser import ModelParser +from pynestml.utils.string_utils import removesuffix +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor +from pynestml.visitors.ast_visitor import ASTVisitor + + +class ConvolutionsTransformer(Transformer): + r"""For each convolution that occurs in the model, allocate one or more needed state variables and replace the convolution() calls by these variable names.""" + + _default_options = { + "convolution_separator": "__conv__", + "diff_order_symbol": "__d", + "simplify_expression": "sympy.logcombine(sympy.powsimp(sympy.expand(expr)))" + } + + def __init__(self, options: Optional[Mapping[str, Any]] = None): + super(Transformer, self).__init__(options) + + # ODE-toolbox printers + self._constant_printer = ConstantPrinter() + self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) + self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) + self._ode_toolbox_printer = ODEToolboxExpressionPrinter(simple_expression_printer=UnitlessSympySimpleExpressionPrinter(variable_printer=self._ode_toolbox_variable_printer, + constant_printer=self._constant_printer, + function_call_printer=self._ode_toolbox_function_call_printer)) + self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer + self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer + + def add_restore_kernel_variables_to_start_of_timestep(self, model, solvers_json): + r"""For each integrate_odes() call in the model, append statements restoring the kernel 
variables to the values at the start of the timestep""" + + var_names = [] + for solver_dict in solvers_json: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + var_names.append(var_name) + + class IntegrateODEsFunctionCallVisitor(ASTVisitor): + all_args = None + + def __init__(self): + super().__init__() + + def visit_small_stmt(self, node: ASTSmallStmt): + self._visit(node) + + def visit_simple_expression(self, node: ASTSimpleExpression): + self._visit(node) + + def _visit(self, node): + if node.is_function_call() and node.get_function_call().get_name() == PredefinedFunctions.INTEGRATE_ODES: + parent_stmt = node.get_parent() + parent_block = parent_stmt.get_parent() + assert isinstance(parent_block, ASTBlock) + idx = parent_block.stmts.index(parent_stmt) + + for i, var_name in enumerate(var_names): + var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol) + var.update_scope(parent_block.get_scope()) + expr = ASTNodeFactory.create_ast_simple_expression(variable=var) + ast_assignment = ASTNodeFactory.create_ast_assignment(lhs=ASTUtils.get_variable_by_name(model, var_name), + is_direct_assignment=True, + expression=expr, source_position=ASTSourceLocation.get_added_source_position()) + ast_assignment.update_scope(parent_block.get_scope()) + ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment) + ast_small_stmt.update_scope(parent_block.get_scope()) + ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt) + ast_stmt.update_scope(parent_block.get_scope()) + + parent_block.stmts.insert(idx + i + 1, ast_stmt) + + model.accept(IntegrateODEsFunctionCallVisitor()) + + def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json): + for solver_dict in solvers_json: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + var = ASTUtils.get_variable_by_name(model, var_name) 
+ ASTUtils.add_state_var_to_integrate_odes_calls(model, var) + + model.accept(ASTParentVisitor()) + + + def add_integrate_odes_call_for_kernel_variables(self, model, solvers_json): + var_names = [] + for solver_dict in solvers_json: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + var_names.append(var_name) + + args = ASTUtils.resolve_variables_to_simple_expressions(model, var_names) + ast_function_call = ASTNodeFactory.create_ast_function_call("integrate_odes", args) + ASTUtils.add_function_call_to_update_block(ast_function_call, model) + model.accept(ASTParentVisitor()) + + def add_temporary_kernel_variables_copy(self, model, solvers_json): + var_names = [] + for solver_dict in solvers_json: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + var_names.append(var_name) + + scope = model.get_update_blocks()[0].scope + + for var_name in var_names: + var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol) + var.scope = scope + expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name)) + ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var], + data_type=ASTDataType(is_real=True), + expression=expr, source_position=ASTSourceLocation.get_added_source_position()) + ast_declaration.update_scope(scope) + ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(declaration=ast_declaration) + ast_small_stmt.update_scope(scope) + ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt) + ast_stmt.update_scope(scope) + + model.get_update_blocks()[0].get_block().stmts.insert(0, ast_stmt) + + model.accept(ASTParentVisitor()) + model.accept(ASTSymbolTableVisitor()) + + def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + r"""Transform a model or a list of models. 
Return an updated model or list of models.""" + for model in models: + print("-------- MODEL BEFORE TRANSFORM ------------") + print(model) + kernel_buffers = self.generate_kernel_buffers(model) + odetoolbox_indict = self.transform_kernels_to_json(model, kernel_buffers) + print("odetoolbox indict: " + str(odetoolbox_indict)) + solvers_json, shape_sys, shapes = odetoolbox._analysis(odetoolbox_indict, + disable_stiffness_check=True, + disable_analytic_solver=True, + preserve_expressions=True, + simplify_expression=self.get_option("simplify_expression"), + log_level=FrontendConfiguration.logging_level) + print("odetoolbox outdict: " + str(solvers_json)) + + self.remove_initial_values_for_kernels(model) + self.create_initial_values_for_kernels(model, solvers_json, kernel_buffers) + self.create_spike_update_event_handlers(model, solvers_json, kernel_buffers) + self.replace_convolve_calls_with_buffers_(model) + self.remove_kernel_definitions_from_equations_blocks(model) + self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json) + self.add_restore_kernel_variables_to_start_of_timestep(model, solvers_json) + self.add_temporary_kernel_variables_copy(model, solvers_json) + self.add_integrate_odes_call_for_kernel_variables(model, solvers_json) + self.add_kernel_equations(model, solvers_json) + + print("-------- MODEL AFTER TRANSFORM ------------") + print(model) + print("-------------------------------------------") + + return models + + def construct_kernel_spike_buf_name(self, kernel_var_name: str, spike_input_port: ASTInputPort, order: int, diff_order_symbol: Optional[str] = None): + """ + Construct a kernel-buffer name as ``KERNEL_NAME__conv__INPUT_PORT_NAME`` + + For example, if the kernel is + .. code-block:: + kernel I_kernel = exp(-t / tau_x) + + and the input port is + .. 
code-block:: + pre_spikes nS <- spike + + then the constructed variable will be ``I_kernel__conv__pre_spikes`` + """ + assert type(kernel_var_name) is str + assert type(order) is int + + if isinstance(spike_input_port, ASTSimpleExpression): + spike_input_port = spike_input_port.get_variable() + + if not isinstance(spike_input_port, str): + spike_input_port_name = spike_input_port.get_name() + else: + spike_input_port_name = spike_input_port + + if isinstance(spike_input_port, ASTVariable): + if spike_input_port.has_vector_parameter(): + spike_input_port_name += "_" + str(self.get_numeric_vector_size(spike_input_port)) + + if not diff_order_symbol: + diff_order_symbol = self.get_option("diff_order_symbol") + + return kernel_var_name.replace("$", "__DOLLAR") + self.get_option("convolution_separator") + spike_input_port_name + diff_order_symbol * order + + def replace_rhs_variable(self, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable, + spike_buf: ASTInputPort): + """ + Replace variable names in definitions of kernel dynamics + :param expr: expression in which to replace the variables + :param variable_name_to_replace: variable name to replace in the expression + :param kernel_var: kernel variable instance + :param spike_buf: input port instance + :return: + """ + def replace_kernel_var(node): + if type(node) is ASTSimpleExpression \ + and node.is_variable() \ + and node.get_variable().get_name() == variable_name_to_replace: + var_order = node.get_variable().get_differential_order() + new_variable_name = self.construct_kernel_spike_buf_name( + kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'") + new_variable = ASTVariable(new_variable_name, var_order) + new_variable.set_source_position(node.get_variable().get_source_position()) + node.set_variable(new_variable) + + expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var)) + + def replace_rhs_variables(self, expr: ASTExpression, kernel_buffers:
Mapping[ASTKernel, ASTInputPort]): + """ + Replace variable names in definitions of kernel dynamics. + + Say that the kernel is + + .. code-block:: + + G = -G / tau + + Its variable symbol might be replaced by "G__conv__spikesEx": + + .. code-block:: + + G__conv__spikesEx = -G / tau + + This function updates the right-hand side of `expr` so that it would also read (in this example): + + .. code-block:: + + G__conv__spikesEx = -G__conv__spikesEx / tau + + These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order. + + Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`. + """ + for kernel, spike_buf in kernel_buffers: + for kernel_var in kernel.get_variables(): + variable_name_to_replace = kernel_var.get_name() + self.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace, + kernel_var=kernel_var, spike_buf=spike_buf) + + @classmethod + def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: + r""" + Remove initial values for original declarations (e.g. g_in, g_in', V_m); these will be replaced with the initial value expressions returned from ODE-toolbox. 
+ """ + symbols_to_remove = set() + for equations_block in model.get_equations_blocks(): + for kernel in equations_block.get_kernels(): + for kernel_var in kernel.get_variables(): + kernel_var_order = kernel_var.get_differential_order() + for order in range(kernel_var_order): + symbol_name = kernel_var.get_name() + "'" * order + symbols_to_remove.add(symbol_name) + + decl_to_remove = set() + for symbol_name in symbols_to_remove: + for state_block in model.get_state_blocks(): + for decl in state_block.get_declarations(): + if len(decl.get_variables()) == 1: + if decl.get_variables()[0].get_name() == symbol_name: + decl_to_remove.add(decl) + else: + for var in decl.get_variables(): + if var.get_name() == symbol_name: + decl.variables.remove(var) + + for decl in decl_to_remove: + for state_block in model.get_state_blocks(): + if decl in state_block.get_declarations(): + state_block.get_declarations().remove(decl) + + def create_initial_values_for_kernels(self, model: ASTModel, solver_dicts: List[Dict], kernels: List[ASTKernel]) -> None: + r""" + Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST + """ + for solver_dict in solver_dicts: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + spike_in_port_name = var_name.split(self.get_option("convolution_separator"))[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) + type_str = "real" + if spike_in_port: + differential_order: int = len(re.findall("__d", var_name)) + if differential_order: + type_str = "(s**-" + str(differential_order) + ")" + + expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution) + if not ASTUtils.declaration_in_state_block(model, var_name): + ASTUtils.add_declaration_to_state_block(model, 
var_name, expr, type_str) + + def is_delta_kernel(self, kernel: ASTKernel) -> bool: + """ + Catches definition of kernel, or reference (function call or variable name) of a delta kernel function. + """ + if not isinstance(kernel, ASTKernel): + return False + + if len(kernel.get_variables()) != 1: + # delta kernel not allowed if more than one variable is defined in this kernel + return False + + expr = kernel.get_expressions()[0] + + rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \ + and expr.is_function_call() \ + and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) + + rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \ + and type(expr.get_rhs()) is ASTSimpleExpression \ + and expr.get_rhs().is_function_call() \ + and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) + + return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel + + def replace_convolve_calls_with_buffers_(self, model: ASTModel) -> None: + r""" + Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`. 
+ """ + + def replace_function_call_through_var(_expr=None): + if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve": + convolve = _expr.get_function_call() + el = (convolve.get_args()[0], convolve.get_args()[1]) + sym = convolve.get_args()[0].get_scope().resolve_to_symbol( + convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) + if sym.block_type == BlockType.INPUT: + # swap elements + el = (el[1], el[0]) + var = el[0].get_variable() + spike_input_port = el[1].get_variable() + kernel = model.get_kernel_by_name(var.get_name()) + + _expr.set_function_call(None) + buffer_var = self.construct_kernel_spike_buf_name( + var.get_name(), spike_input_port, var.get_differential_order() - 1) + if self.is_delta_kernel(kernel): + # delta kernels are treated separately, and should be kept out of the dynamics (computing derivates etc.) --> set to zero + _expr.set_variable(None) + _expr.set_numeric_literal(0) + else: + ast_variable = ASTVariable(buffer_var) + ast_variable.set_source_position(_expr.get_source_position()) + _expr.set_variable(ast_variable) + + def func(x): + return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True + + for equations_block in model.get_equations_blocks(): + equations_block.accept(ASTHigherOrderVisitor(func)) + + @classmethod + def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None: + """ + Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``. 
+ """ + def replace_var(_expr, replace_var_name: str, replace_with_var_name: str): + if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): + var = _expr.get_variable() + if var.get_name() == replace_var_name: + ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(), + differential_order=0) + ast_variable.set_source_position(var.get_source_position()) + _expr.set_variable(ast_variable) + + elif isinstance(_expr, ASTVariable): + var = _expr + if var.get_name() == replace_var_name: + var.set_name(replace_with_var_name + '__d' * var.get_differential_order()) + var.set_differential_order(0) + + for equation_block in neuron.get_equations_blocks(): + for decl in equation_block.get_declarations(): + if isinstance(decl, ASTInlineExpression): + expr = decl.get_expression() + if isinstance(expr, ASTExpression): + expr = expr.get_lhs() + + if isinstance(expr, ASTSimpleExpression) \ + and '__X__' in str(expr) \ + and expr.get_variable(): + replace_with_var_name = expr.get_variable().get_name() + neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var( + x, decl.get_variable_name(), replace_with_var_name))) + + def generate_kernel_buffers(self, model: ASTModel) -> Mapping[ASTKernel, ASTInputPort]: + r""" + For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`. 
+ """ + kernel_buffers = set() + for equations_block in model.get_equations_blocks(): + convolve_calls = ASTUtils.get_convolve_function_calls(equations_block) + for convolve in convolve_calls: + el = (convolve.get_args()[0], convolve.get_args()[1]) + sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) + if sym is None: + raise Exception("No initial value(s) defined for kernel with variable \"" + + convolve.get_args()[0].get_variable().get_complete_name() + "\"") + if sym.block_type == BlockType.INPUT: + # swap the order + el = (el[1], el[0]) + + # find the corresponding kernel object + var = el[0].get_variable() + assert var is not None + kernel = model.get_kernel_by_name(var.get_name()) + assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str( + el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model." + + el = (kernel, el[1]) + kernel_buffers.add(el) + + return kernel_buffers + + def add_kernel_equations(self, model, solver_dicts): + if not model.get_equations_blocks(): + ASTUtils.create_equations_block() + + assert len(model.get_equations_blocks()) <= 1 + + equations_block = model.get_equations_blocks()[0] + + for solver_dict in solver_dicts: + if solver_dict is None: + continue + + for var_name, expr_str in solver_dict["update_expressions"].items(): + expr = ModelParser.parse_expression(expr_str) + expr.update_scope(model.get_scope()) + expr.accept(ASTSymbolTableVisitor()) + + var = ASTNodeFactory.create_ast_variable(var_name, differential_order=1, source_position=ASTSourceLocation.get_added_source_position()) + var.update_scope(equations_block.get_scope()) + ast_ode_equation = ASTNodeFactory.create_ast_ode_equation(lhs=var, rhs=expr, source_position=ASTSourceLocation.get_added_source_position()) + ast_ode_equation.update_scope(equations_block.get_scope()) + equations_block.declarations.append(ast_ode_equation) + + 
model.accept(ASTParentVisitor()) + model.accept(ASTSymbolTableVisitor()) + + def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> ASTDeclaration: + r""" + Removes all kernels in equations blocks. + """ + for equations_block in model.get_equations_blocks(): + decl_to_remove = set() + for decl in equations_block.get_declarations(): + if type(decl) is ASTKernel: + decl_to_remove.add(decl) + + for decl in decl_to_remove: + equations_block.get_declarations().remove(decl) + + def transform_kernels_to_json(self, model: ASTModel, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> Dict: + """ + Converts AST node to a JSON representation suitable for passing to ode-toolbox. + + Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements + + .. code-block:: + + convolve(G, exc_spikes) + convolve(G, inh_spikes) + + then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`. + """ + odetoolbox_indict = {} + odetoolbox_indict["dynamics"] = [] + + for kernel, spike_input_port in kernel_buffers: + + if self.is_delta_kernel(kernel): + # delta function -- skip passing this to ode-toolbox + continue + + for kernel_var in kernel.get_variables(): + expr = ASTUtils.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name()) + kernel_order = kernel_var.get_differential_order() + kernel_X_spike_buf_name_ticks = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") + + self.replace_rhs_variables(expr, kernel_buffers) + + entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} + + # initial values need to be declared for order 1 up to kernel order (e.g. 
none for kernel function + # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on) + for order in range(kernel_order): + iv_sym_name_ode_toolbox = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") + symbol_name_ = kernel_var.get_name() + "'" * order + symbol = model.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE) + assert symbol is not None, "Could not find initial value for variable " + symbol_name_ + initial_value_expr = symbol.get_declaring_expression() + assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ + entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + + odetoolbox_indict["parameters"] = {} + for parameters_block in model.get_parameters_blocks(): + for decl in parameters_block.get_declarations(): + for var in decl.variables: + odetoolbox_indict["parameters"][var.get_complete_name()] = self._ode_toolbox_printer.print(decl.get_expression()) + + return odetoolbox_indict + + def create_spike_update_event_handlers(self, model: ASTModel, solver_dicts, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: + r""" + Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after + ode-toolbox. + + For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model. + from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from + user specification in the model. 
+ + Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the + initial value of the corresponding ODE dimension. + """ + + spike_in_port_to_stmts = {} + for solver_dict in solver_dicts: + for var, expr in solver_dict["initial_values"].items(): + expr = str(expr) + if expr in ["0", "0.", "0.0"]: + continue # skip adding the statement if we are only adding zero + + spike_in_port_name = var.split(self.get_option("convolution_separator"))[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) + type_str = "real" + + assert spike_in_port + differential_order: int = len(re.findall("__d", var)) + if differential_order: + type_str = "(s**-" + str(differential_order) + ")" + + assignment_str = var + " += " + assignment_str += "(" + str(spike_in_port_name) + ")" + if not expr in ["1.", "1.0", "1"]: + assignment_str += " * (" + expr + ")" + + ast_assignment = ModelParser.parse_assignment(assignment_str) + ast_assignment.update_scope(model.get_scope()) + ast_assignment.accept(ASTSymbolTableVisitor()) + + ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment) + ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt) + + if not spike_in_port_name in spike_in_port_to_stmts.keys(): + spike_in_port_to_stmts[spike_in_port_name] = [] + + spike_in_port_to_stmts[spike_in_port_name].append(ast_stmt) + + # for every input port, add an onreceive block with its update statements + for in_port, stmts in spike_in_port_to_stmts.items(): + stmts_block = ASTNodeFactory.create_ast_block(stmts, ASTSourceLocation.get_added_source_position()) + on_receive_block = ASTNodeFactory.create_ast_on_receive_block(stmts_block, + in_port, + const_parameters=None, # XXX: TODO: add priority here! 
+ source_position=ASTSourceLocation.get_added_source_position()) + + model.get_body().get_body_elements().append(on_receive_block) + + model.accept(ASTParentVisitor()) diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index b58f526c7..7df9b025e 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -165,51 +165,6 @@ def get_neuron_var_name_from_syn_port_name(self, port_name: str, neuron_name: st return None - def get_convolve_with_not_post_vars(self, nodes: Union[ASTEquationsBlock, Sequence[ASTEquationsBlock]], neuron_name: str, synapse_name: str, parent_node: ASTNode): - class ASTVariablesUsedInConvolutionVisitor(ASTVisitor): - _variables = [] - - def __init__(self, node: ASTNode, parent_node: ASTNode, codegen_class): - super(ASTVariablesUsedInConvolutionVisitor, self).__init__() - self.node = node - self.parent_node = parent_node - self.codegen_class = codegen_class - - def visit_function_call(self, node): - func_name = node.get_name() - if func_name == "convolve": - symbol_buffer = node.get_scope().resolve_to_symbol(str(node.get_args()[1]), - SymbolKind.VARIABLE) - input_port = ASTUtils.get_input_port_by_name( - self.parent_node.get_input_blocks(), symbol_buffer.name) - if input_port and not self.codegen_class.is_post_port(input_port.name, neuron_name, synapse_name): - kernel_name = node.get_args()[0].get_variable().name - self._variables.append(kernel_name) - - found_parent_assignment = False - node_ = node - while not found_parent_assignment: - node_ = node_.get_parent() - # XXX TODO also needs to accept normal ASTExpression, ASTAssignment? 
- if isinstance(node_, ASTInlineExpression): - found_parent_assignment = True - var_name = node_.get_variable_name() - self._variables.append(var_name) - - if not nodes: - return [] - - if isinstance(nodes, ASTNode): - nodes = [nodes] - - variables = [] - for node in nodes: - visitor = ASTVariablesUsedInConvolutionVisitor(node, parent_node, self) - node.accept(visitor) - variables.extend(visitor._variables) - - return variables - def get_all_variables_assigned_to(self, node): r"""Return a list of all variables that are assigned to in ``node``.""" class ASTAssignedToVariablesFinderVisitor(ASTVisitor): @@ -272,13 +227,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): all_state_vars = [var.get_complete_name() for var in all_state_vars] - # add names of convolutions - all_state_vars += ASTUtils.get_all_variables_used_in_convolutions(synapse.get_equations_blocks(), synapse) - - # add names of kernels - kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, synapse.get_equations_blocks()) - all_state_vars += [var.name for k in kernel_buffers for var in k[0].variables] - # exclude certain variables from being moved: # exclude any variable assigned to in any block that is not connected to a postsynaptic port strictly_synaptic_vars = ["t"] # "seed" this with the predefined variable t @@ -297,28 +245,25 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): for update_block in synapse.get_update_blocks(): strictly_synaptic_vars += self.get_all_variables_assigned_to(update_block) - # exclude convolutions if they are not with a postsynaptic variable - convolve_with_not_post_vars = self.get_convolve_with_not_post_vars(synapse.get_equations_blocks(), neuron.name, synapse.name, synapse) - # exclude all variables that depend on the ones that are not to be moved strictly_synaptic_vars_dependent = ASTUtils.recursive_dependent_variables_search(strictly_synaptic_vars, synapse) # do set subtraction - syn_to_neuron_state_vars = list(set(all_state_vars) - 
(set(strictly_synaptic_vars) | set(convolve_with_not_post_vars) | set(strictly_synaptic_vars_dependent))) + syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(strictly_synaptic_vars_dependent))) # - # collect all the variable/parameter/kernel/function/etc. names used in defining expressions of `syn_to_neuron_state_vars` + # collect all the variable/parameter/function/etc. names used in defining expressions of `syn_to_neuron_state_vars` # recursive_vars_used = ASTUtils.recursive_necessary_variables_search(syn_to_neuron_state_vars, synapse) new_neuron.recursive_vars_used = recursive_vars_used new_neuron._transferred_variables = [neuron_state_var + var_name_suffix - for neuron_state_var in syn_to_neuron_state_vars if new_synapse.get_kernel_by_name(neuron_state_var) is None] + for neuron_state_var in syn_to_neuron_state_vars] # all state variables that will be moved from synapse to neuron syn_to_neuron_state_vars = [] for var_name in recursive_vars_used: - if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name) or ASTUtils.get_kernel_by_name(synapse, var_name): + if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name): syn_to_neuron_state_vars.append(var_name) Logger.log_message(None, -1, "State variables that will be moved from synapse to neuron: " + str(syn_to_neuron_state_vars), @@ -412,33 +357,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): block_type=BlockType.STATE, mode="move") - # - # mark variables in the neuron pertaining to synapse postsynaptic ports - # - # convolutions with them ultimately yield variable updates when post neuron calls emit_spike() - # - - def mark_post_ports(neuron, synapse, mark_node): - post_ports = [] - - def mark_post_port(_expr=None): - var = None - if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): - var = _expr.get_variable() - elif 
isinstance(_expr, ASTVariable): - var = _expr - - if var: - var_base_name = var.name[:-len(var_name_suffix)] # prune the suffix - if self.is_post_port(var_base_name, neuron.name, synapse.name): - post_ports.append(var) - var._is_post_port = True - - mark_node.accept(ASTHigherOrderVisitor(lambda x: mark_post_port(x))) - return post_ports - - mark_post_ports(new_neuron, new_synapse, new_neuron) - # # move statements in post receive block from synapse to new_neuron # @@ -586,6 +504,12 @@ def mark_post_port(_expr=None): return new_neuron, new_synapse def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + # check that there are no convolutions or kernels in the model (these should have been transformed out by the ConvolutionsTransformer) + for model in models: + for equations_block in model.get_equations_blocks(): + assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer" + + # transform each (neuron, synapse) pair for neuron_synapse_pair in self.get_option("neuron_synapse_pairs"): neuron_name = neuron_synapse_pair["neuron"] synapse_name = neuron_synapse_pair["synapse"] diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 2341d76b6..5f0fcccda 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -442,6 +442,25 @@ def create_internal_block(cls, model: ASTModel): model.accept(ASTParentVisitor()) return model + + @classmethod + def create_on_receive_block(cls, model: ASTModel, block: ASTBlock, input_port_name: str) -> ASTModel: + """ + Creates a single onReceive block in the handed over model. 
+ :param model: a single model + :return: the modified model + """ + # local import since otherwise circular dependency + from pynestml.meta_model.ast_node_factory import ASTNodeFactory + block = ASTNodeFactory.create_ast_on_receive_block(block, input_port_name, + ASTSourceLocation.get_added_source_position()) + block.update_scope(model.get_scope()) + model.get_body().get_body_elements().append(block) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + + return model @classmethod def create_state_block(cls, model: ASTModel): @@ -560,6 +579,30 @@ def inline_aliases_convolution(cls, inline_expr: ASTInlineExpression) -> bool: return True return False + @classmethod + def add_state_var_to_integrate_odes_calls(cls, model: ASTModel, var: ASTExpression): + r"""Add a state variable to the arguments to each integrate_odes() calls in the model.""" + + class AddStateVarToIntegrateODEsCallsVisitor(ASTVisitor): + def visit_function_call(self, node: ASTFunctionCall): + if node.get_name() == PredefinedFunctions.INTEGRATE_ODES: + expr = ASTNodeFactory.create_ast_simple_expression(variable=var.clone()) + node.args.append(expr) + + model.accept(AddStateVarToIntegrateODEsCallsVisitor()) + + @classmethod + def resolve_variables_to_simple_expressions(cls, model, vars): + """receives a list of variable names (as strings) and returns a list of ASTSimpleExpressions containing each ASTVariable""" + expressions = [] + + for var_name in vars: + node = ASTUtils.get_variable_by_name(model, var_name) + assert node is not None + expressions.append(ASTNodeFactory.create_ast_simple_expression(variable=node)) + + return expressions + @classmethod def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str, scope=None): """add suffix to variable by given name recursively throughout astnode""" @@ -1027,7 +1070,33 @@ def has_equation_with_delay_variable(cls, equations_with_delay_vars: ASTOdeEquat if 
equation.get_lhs().get_name() == sym: return True return False + + @classmethod + def add_function_call_to_update_block(cls, function_call: ASTFunctionCall, model: ASTModel) -> ASTModel: + """ + Adds a single assignment to the end of the update block of the handed over model. + :param function_call: a single function call + :param neuron: a single model instance + :return: the modified model + """ + assert len(model.get_update_blocks()) <= 1, "At most one update block should be present" + if not model.get_update_blocks(): + model.create_empty_update_block() + + small_stmt = ASTNodeFactory.create_ast_small_stmt(function_call=function_call, + source_position=ASTSourceLocation.get_added_source_position()) + stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt, + source_position=ASTSourceLocation.get_added_source_position()) + model.get_update_blocks()[0].get_block().get_stmts().append(stmt) + small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope()) + stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope()) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + + return model + @classmethod def add_declarations_to_internals(cls, neuron: ASTModel, declarations: Mapping[str, str]) -> ASTModel: """ @@ -1318,119 +1387,6 @@ def all_convolution_variable_names(cls, model: ASTModel) -> List[str]: var_names = [var.get_complete_name() for var in vars if "__X__" in var.get_complete_name()] return var_names - @classmethod - def construct_kernel_X_spike_buf_name(cls, kernel_var_name: str, spike_input_port: ASTInputPort, order: int, - diff_order_symbol="__d"): - """ - Construct a kernel-buffer name as - - For example, if the kernel is - .. code-block:: - kernel I_kernel = exp(-t / tau_x) - - and the input port is - .. 
code-block:: - pre_spikes nS <- spike - - then the constructed variable will be 'I_kernel__X__pre_pikes' - """ - assert type(kernel_var_name) is str - assert type(order) is int - assert type(diff_order_symbol) is str - - if isinstance(spike_input_port, ASTSimpleExpression): - spike_input_port = spike_input_port.get_variable() - - if not isinstance(spike_input_port, str): - spike_input_port_name = spike_input_port.get_name() - else: - spike_input_port_name = spike_input_port - - if isinstance(spike_input_port, ASTVariable): - if spike_input_port.has_vector_parameter(): - spike_input_port_name += "_" + str(cls.get_numeric_vector_size(spike_input_port)) - - return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + spike_input_port_name + diff_order_symbol * order - - @classmethod - def replace_rhs_variable(cls, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable, - spike_buf: ASTInputPort): - """ - Replace variable names in definitions of kernel dynamics - :param expr: expression in which to replace the variables - :param variable_name_to_replace: variable name to replace in the expression - :param kernel_var: kernel variable instance - :param spike_buf: input port instance - :return: - """ - def replace_kernel_var(node): - if type(node) is ASTSimpleExpression \ - and node.is_variable() \ - and node.get_variable().get_name() == variable_name_to_replace: - var_order = node.get_variable().get_differential_order() - new_variable_name = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'") - new_variable = ASTVariable(new_variable_name, var_order) - new_variable.set_source_position(node.get_variable().get_source_position()) - node.set_variable(new_variable) - - expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var)) - - @classmethod - def replace_rhs_variables(cls, expr: ASTExpression, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): - """ - Replace variable names in 
definitions of kernel dynamics. - - Say that the kernel is - - .. code-block:: - - G = -G / tau - - Its variable symbol might be replaced by "G__X__spikesEx": - - .. code-block:: - - G__X__spikesEx = -G / tau - - This function updates the right-hand side of `expr` so that it would also read (in this example): - - .. code-block:: - - G__X__spikesEx = -G__X__spikesEx / tau - - These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order. - - Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`. - """ - for kernel, spike_buf in kernel_buffers: - for kernel_var in kernel.get_variables(): - variable_name_to_replace = kernel_var.get_name() - cls.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace, - kernel_var=kernel_var, spike_buf=spike_buf) - - @classmethod - def is_delta_kernel(cls, kernel: ASTKernel) -> bool: - """ - Catches definition of kernel, or reference (function call or variable name) of a delta kernel function. 
- """ - if type(kernel) is ASTKernel: - if not len(kernel.get_variables()) == 1: - # delta kernel not allowed if more than one variable is defined in this kernel - return False - expr = kernel.get_expressions()[0] - else: - expr = kernel - - rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \ - and expr.is_function_call() \ - and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) - rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \ - and type(expr.get_rhs()) is ASTSimpleExpression \ - and expr.get_rhs().is_function_call() \ - and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) - return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel - @classmethod def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: str) -> ASTInputPort: """ @@ -1670,37 +1626,6 @@ def recursive_necessary_variables_search(cls, vars: List[str], model: ASTModel) return list(set(vars_used)) - @classmethod - def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: - """ - Remove initial values for original declarations (e.g. g_in, g_in', V_m); these might conflict with the initial value expressions returned from ODE-toolbox. 
- """ - symbols_to_remove = set() - for equations_block in model.get_equations_blocks(): - for kernel in equations_block.get_kernels(): - for kernel_var in kernel.get_variables(): - kernel_var_order = kernel_var.get_differential_order() - for order in range(kernel_var_order): - symbol_name = kernel_var.get_name() + "'" * order - symbols_to_remove.add(symbol_name) - - decl_to_remove = set() - for symbol_name in symbols_to_remove: - for state_block in model.get_state_blocks(): - for decl in state_block.get_declarations(): - if len(decl.get_variables()) == 1: - if decl.get_variables()[0].get_name() == symbol_name: - decl_to_remove.add(decl) - else: - for var in decl.get_variables(): - if var.get_name() == symbol_name: - decl.variables.remove(var) - - for decl in decl_to_remove: - for state_block in model.get_state_blocks(): - if decl in state_block.get_declarations(): - state_block.get_declarations().remove(decl) - @classmethod def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ @@ -1956,53 +1881,9 @@ def _visit(self, node): return visitor.calls @classmethod - def create_initial_values_for_kernels(cls, model: ASTModel, solver_dicts: List[Dict], kernels: List[ASTKernel]) -> None: - r""" - Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST - """ - for solver_dict in solver_dicts: - if solver_dict is None: - continue - - for var_name in solver_dict["initial_values"].keys(): - if cls.variable_in_kernels(var_name, kernels): - # original initial value expressions should have been removed to make place for ode-toolbox results - assert not cls.declaration_in_state_block(model, var_name) - - for solver_dict in solver_dicts: - if solver_dict is None: - continue - - for var_name, expr in solver_dict["initial_values"].items(): - # overwrite is allowed because initial values might be repeated between numeric and analytic solver - if cls.variable_in_kernels(var_name, kernels): - 
spike_in_port_name = var_name.split("__X__")[1] - spike_in_port_name = spike_in_port_name.split("__d")[0] - spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) - type_str = "real" - if spike_in_port: - differential_order: int = len(re.findall("__d", var_name)) - if differential_order: - type_str = "(s**-" + str(differential_order) + ")" - - expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution) - if not cls.declaration_in_state_block(model, var_name): - cls.add_declaration_to_state_block(model, var_name, expr, type_str) - - @classmethod - def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: Sequence[ASTBlockWithVariables], - kernel_buffers: Mapping[ASTKernel, ASTInputPort], printer: ASTPrinter) -> Dict: + def transform_odes_to_json(cls, model: ASTModel, parameters_blocks: Sequence[ASTBlockWithVariables], printer: ASTPrinter) -> Dict: """ Converts AST node to a JSON representation suitable for passing to ode-toolbox. - - Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements - - .. code-block:: - - convolve(G, exc_spikes) - convolve(G, inh_spikes) - - then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`. 
""" odetoolbox_indict = {} @@ -2027,37 +1908,6 @@ def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: S odetoolbox_indict["dynamics"].append(entry) - # write a copy for each (kernel, spike buffer) combination - for kernel, spike_input_port in kernel_buffers: - - if cls.is_delta_kernel(kernel): - # delta function -- skip passing this to ode-toolbox - continue - - for kernel_var in kernel.get_variables(): - expr = cls.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name()) - kernel_order = kernel_var.get_differential_order() - kernel_X_spike_buf_name_ticks = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") - - cls.replace_rhs_variables(expr, kernel_buffers) - - entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} - - # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function - # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... 
and so on) - for order in range(kernel_order): - iv_sym_name_ode_toolbox = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") - symbol_name_ = kernel_var.get_name() + "'" * order - symbol = equations_block.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE) - assert symbol is not None, "Could not find initial value for variable " + symbol_name_ - initial_value_expr = symbol.get_declaring_expression() - assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ - entry["initial_values"][iv_sym_name_ode_toolbox] = printer.print(initial_value_expr) - - odetoolbox_indict["dynamics"].append(entry) - odetoolbox_indict["parameters"] = {} for parameters_block in parameters_blocks: for decl in parameters_block.get_declarations(): @@ -2079,52 +1929,6 @@ def remove_ode_definitions_from_equations_block(cls, model: ASTModel) -> None: for decl in decl_to_remove: equations_block.get_declarations().remove(decl) - @classmethod - def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock) -> dict: - r""" - For every occurrence of a convolution of the form `x^(n) = a * convolve(kernel, inport) + ...` where `kernel` is a delta function, add the element `(x^(n), inport) --> a` to the set. 
- """ - delta_factors = {} - - for ode_eq in equations_block.get_ode_equations(): - var = ode_eq.get_lhs() - expr = ode_eq.get_rhs() - conv_calls = ASTUtils.get_convolve_function_calls(expr) - for conv_call in conv_calls: - assert len( - conv_call.args) == 2, "convolve() function call should have precisely two arguments: kernel and spike input port" - kernel = conv_call.args[0] - if cls.is_delta_kernel(neuron.get_kernel_by_name(kernel.get_variable().get_name())): - inport = conv_call.args[1].get_variable() - expr_str = str(expr) - sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str, global_dict=odetoolbox.Shape._sympy_globals) - sympy_expr = sympy.expand(sympy_expr) - sympy_conv_expr = sympy.parsing.sympy_parser.parse_expr(str(conv_call), global_dict=odetoolbox.Shape._sympy_globals) - factor_str = [] - for term in sympy.Add.make_args(sympy_expr): - if term.find(sympy_conv_expr): - factor_str.append(str(term.replace(sympy_conv_expr, 1))) - factor_str = " + ".join(factor_str) - delta_factors[(var, inport)] = factor_str - - return delta_factors - - @classmethod - def remove_kernel_definitions_from_equations_block(cls, model: ASTModel) -> ASTDeclaration: - r""" - Removes all kernels in equations blocks. 
- """ - for equations_block in model.get_equations_blocks(): - decl_to_remove = set() - for decl in equations_block.get_declarations(): - if type(decl) is ASTKernel: - decl_to_remove.add(decl) - - for decl in decl_to_remove: - equations_block.get_declarations().remove(decl) - - return decl_to_remove - @classmethod def add_timestep_symbol(cls, model: ASTModel) -> None: """ @@ -2137,70 +1941,6 @@ def add_timestep_symbol(cls, model: ASTModel) -> None: )], "\"__h\" is a reserved name, please do not use variables by this name in your NESTML file" model.add_to_internals_block(ModelParser.parse_declaration('__h ms = resolution()'), index=0) - @classmethod - def generate_kernel_buffers(cls, model: ASTModel, equations_block: Union[ASTEquationsBlock, List[ASTEquationsBlock]]) -> Mapping[ASTKernel, ASTInputPort]: - """ - For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`. - """ - - kernel_buffers = set() - convolve_calls = ASTUtils.get_convolve_function_calls(equations_block) - for convolve in convolve_calls: - el = (convolve.get_args()[0], convolve.get_args()[1]) - sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) - if sym is None: - raise Exception("No initial value(s) defined for kernel with variable \"" - + convolve.get_args()[0].get_variable().get_complete_name() + "\"") - if sym.block_type == BlockType.INPUT: - # swap the order - el = (el[1], el[0]) - - # find the corresponding kernel object - var = el[0].get_variable() - assert var is not None - kernel = model.get_kernel_by_name(var.get_name()) - assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str( - el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model." 
- - el = (kernel, el[1]) - kernel_buffers.add(el) - - return kernel_buffers - - @classmethod - def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None: - """ - Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``. - """ - def replace_var(_expr, replace_var_name: str, replace_with_var_name: str): - if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): - var = _expr.get_variable() - if var.get_name() == replace_var_name: - ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(), - differential_order=0) - ast_variable.set_source_position(var.get_source_position()) - _expr.set_variable(ast_variable) - - elif isinstance(_expr, ASTVariable): - var = _expr - if var.get_name() == replace_var_name: - var.set_name(replace_with_var_name + '__d' * var.get_differential_order()) - var.set_differential_order(0) - - for equation_block in neuron.get_equations_blocks(): - for decl in equation_block.get_declarations(): - if isinstance(decl, ASTInlineExpression): - expr = decl.get_expression() - if isinstance(expr, ASTExpression): - expr = expr.get_lhs() - - if isinstance(expr, ASTSimpleExpression) \ - and '__X__' in str(expr) \ - and expr.get_variable(): - replace_with_var_name = expr.get_variable().get_name() - neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var( - x, decl.get_variable_name(), replace_with_var_name))) - @classmethod def replace_variable_names_in_expressions(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ @@ -2457,13 +2197,6 @@ def visit_variable(self, node): for expr in numeric_update_expressions.values(): expr.accept(visitor) - for update_expr_list in neuron.spike_updates.values(): - for update_expr in update_expr_list: - update_expr.accept(visitor) - - for update_expr in 
neuron.post_spike_updates.values(): - update_expr.accept(visitor) - for node in neuron.equations_with_delay_vars + neuron.equations_with_vector_vars: node.accept(visitor)