From 076bd8943a5a2c3e2ab1666fc5f105298b1c8801 Mon Sep 17 00:00:00 2001 From: clinssen Date: Tue, 21 May 2024 09:54:58 +0200 Subject: [PATCH] Refactor ``ASTNode.get_parent()`` to improve runtime performance (#1049) * refactor ASTNode.get_parent() to improve runtime performance --------- Co-authored-by: C.A.P. Linssen --- pynestml/cocos/co_co_all_variables_defined.py | 2 +- .../co_co_no_kernels_except_in_convolve.py | 4 +- .../co_co_resolution_func_legally_used.py | 2 +- pynestml/cocos/co_co_simple_delta_function.py | 2 +- .../co_co_vector_declaration_right_size.py | 2 +- .../meta_model/ast_arithmetic_operator.py | 20 ++--- pynestml/meta_model/ast_assignment.py | 63 ++++--------- pynestml/meta_model/ast_bit_operator.py | 25 +++--- pynestml/meta_model/ast_block.py | 32 +++---- .../meta_model/ast_block_with_variables.py | 30 +++---- .../meta_model/ast_comparison_operator.py | 26 +++--- pynestml/meta_model/ast_compound_stmt.py | 49 +++++----- pynestml/meta_model/ast_data_type.py | 35 ++++---- pynestml/meta_model/ast_declaration.py | 78 +++++++--------- pynestml/meta_model/ast_elif_clause.py | 43 ++++----- pynestml/meta_model/ast_else_clause.py | 34 ++++--- pynestml/meta_model/ast_equations_block.py | 34 +++---- pynestml/meta_model/ast_expression.py | 90 +++++++++---------- pynestml/meta_model/ast_expression_node.py | 18 +++- pynestml/meta_model/ast_for_stmt.py | 51 +++++------ pynestml/meta_model/ast_function.py | 54 +++++------ pynestml/meta_model/ast_function_call.py | 30 +++---- pynestml/meta_model/ast_if_clause.py | 43 ++++----- pynestml/meta_model/ast_if_stmt.py | 50 +++++------ pynestml/meta_model/ast_inline_expression.py | 46 +++++----- pynestml/meta_model/ast_input_block.py | 35 +++----- pynestml/meta_model/ast_input_port.py | 47 +++++----- pynestml/meta_model/ast_input_qualifier.py | 25 +++--- pynestml/meta_model/ast_kernel.py | 44 +++------ pynestml/meta_model/ast_logical_operator.py | 27 +++--- pynestml/meta_model/ast_model.py | 43 ++++----- pynestml/meta_model/ast_model_body.py | 30 ++----- .../meta_model/ast_namespace_decorator.py | 23 ++--- .../meta_model/ast_nestml_compilation_unit.py | 26 ++---- pynestml/meta_model/ast_node.py | 33 ++++--- pynestml/meta_model/ast_node_factory.py | 6 +- pynestml/meta_model/ast_ode_equation.py | 41 ++++----- pynestml/meta_model/ast_on_condition_block.py | 26 +++--- pynestml/meta_model/ast_on_receive_block.py | 24 ++--- pynestml/meta_model/ast_output_block.py | 25 +++--- pynestml/meta_model/ast_parameter.py | 30 +++---- pynestml/meta_model/ast_return_stmt.py | 32 +++---- pynestml/meta_model/ast_simple_expression.py | 50 +++++------ pynestml/meta_model/ast_small_stmt.py | 50 ++++------- pynestml/meta_model/ast_stmt.py | 34 +++---- pynestml/meta_model/ast_unary_operator.py | 25 +++--- pynestml/meta_model/ast_unit_type.py | 56 +++++------- pynestml/meta_model/ast_update_block.py | 29 +++--- pynestml/meta_model/ast_variable.py | 39 ++++---- pynestml/meta_model/ast_while_stmt.py | 42 ++++----- .../synapse_post_neuron_transformer.py | 9 +- .../transformers/synapse_remove_post_port.py | 6 ++ pynestml/utils/ast_utils.py | 41 +++++++-- pynestml/utils/chan_info_enricher.py | 4 +- pynestml/utils/logger.py | 1 - pynestml/utils/mechs_info_enricher.py | 12 ++- pynestml/utils/model_parser.py | 6 ++ pynestml/visitors/ast_parent_visitor.py | 38 ++++++++ .../CoCoVectorInNonVectorDeclaration.nestml | 1 + tests/test_symbol_table_builder.py | 3 + 60 files changed, 823 insertions(+), 1003 deletions(-) create mode 100644 pynestml/visitors/ast_parent_visitor.py diff --git a/pynestml/cocos/co_co_all_variables_defined.py b/pynestml/cocos/co_co_all_variables_defined.py index 0cabf28b4..a02adace2 100644 --- a/pynestml/cocos/co_co_all_variables_defined.py +++ b/pynestml/cocos/co_co_all_variables_defined.py @@ -101,7 +101,7 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): # check if it is part of an invariant # if it is the case, there is no "recursive" declaration # so check if the parent is a declaration and the expression the invariant - expr_par = node.get_parent(expr) + expr_par = expr.get_parent() if isinstance(expr_par, ASTDeclaration) and expr_par.get_invariant() == expr: # in this case its ok if it is recursive or defined later on continue diff --git a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py index 4c4c65873..c991ca351 100644 --- a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py +++ b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py @@ -93,11 +93,11 @@ def visit_variable(self, node: ASTNode): if not symbol.is_kernel(): continue if node.get_complete_name() == kernelName: - parent = self.__neuron_node.get_parent(node) + parent = node.get_parent() if parent is not None: if isinstance(parent, ASTKernel): continue - grandparent = self.__neuron_node.get_parent(parent) + grandparent = parent.get_parent() if grandparent is not None and isinstance(grandparent, ASTFunctionCall): grandparent_func_name = grandparent.get_name() if grandparent_func_name == 'convolve': diff --git a/pynestml/cocos/co_co_resolution_func_legally_used.py b/pynestml/cocos/co_co_resolution_func_legally_used.py index 783fd0ff2..08b4ec70d 100644 --- a/pynestml/cocos/co_co_resolution_func_legally_used.py +++ b/pynestml/cocos/co_co_resolution_func_legally_used.py @@ -62,7 +62,7 @@ def visit_simple_expression(self, node): if function_name == PredefinedFunctions.TIME_RESOLUTION: _node = node while _node: - _node = self.neuron.get_parent(_node) + _node = _node.get_parent() if isinstance(_node, ASTEquationsBlock) \ or isinstance(_node, ASTFunction): diff --git a/pynestml/cocos/co_co_simple_delta_function.py b/pynestml/cocos/co_co_simple_delta_function.py index 065ec7546..7f3de6658 100644 --- a/pynestml/cocos/co_co_simple_delta_function.py +++ b/pynestml/cocos/co_co_simple_delta_function.py @@ -44,7 +44,7 @@ def check_co_co(cls, model: ASTModel): def check_simple_delta(_expr=None): if _expr.is_function_call() and _expr.get_function_call().get_name() == "delta": deltafunc = _expr.get_function_call() - parent = model.get_parent(_expr) + parent = _expr.get_parent() # check the argument if not (len(deltafunc.get_args()) == 1 diff --git a/pynestml/cocos/co_co_vector_declaration_right_size.py b/pynestml/cocos/co_co_vector_declaration_right_size.py index 61a70de42..6597f481f 100644 --- a/pynestml/cocos/co_co_vector_declaration_right_size.py +++ b/pynestml/cocos/co_co_vector_declaration_right_size.py @@ -50,7 +50,7 @@ class VectorDeclarationVisitor(ASTVisitor): def visit_variable(self, node: ASTVariable): vector_parameter = node.get_vector_parameter() if vector_parameter is not None: - if isinstance(self._neuron.get_parent(node), ASTDeclaration): + if isinstance(node.get_parent(), ASTDeclaration): # node is being declared: size should be >= 1 min_index = 1 diff --git a/pynestml/meta_model/ast_arithmetic_operator.py b/pynestml/meta_model/ast_arithmetic_operator.py index b952c1c6d..59bc3fd04 100644 --- a/pynestml/meta_model/ast_arithmetic_operator.py +++ b/pynestml/meta_model/ast_arithmetic_operator.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List from pynestml.meta_model.ast_node import ASTNode @@ -71,23 +72,20 @@ def clone(self): return dup - def get_parent(self, ast): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - return None + return [] - def equals(self, other): - # type: (ASTNode) -> bool - """ + def equals(self, other: ASTNode) -> bool: + r""" The equality method. """ if not isinstance(other, ASTArithmeticOperator): return False + return (self.is_times_op == other.is_times_op and self.is_div_op == other.is_div_op and self.is_modulo_op == other.is_modulo_op and self.is_plus_op == other.is_plus_op and self.is_minus_op == other.is_minus_op and self.is_pow_op == other.is_pow_op) diff --git a/pynestml/meta_model/ast_assignment.py b/pynestml/meta_model/ast_assignment.py index 2d4f4faa2..83e03e3e7 100644 --- a/pynestml/meta_model/ast_assignment.py +++ b/pynestml/meta_model/ast_assignment.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional +from typing import List, Optional from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_variable import ASTVariable @@ -124,34 +124,27 @@ def get_expression(self): """ return self.rhs - def get_parent(self, ast): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - if self.get_variable() is ast: - return self - if self.get_expression() is ast: - return self - if self.get_variable().get_parent(ast) is not None: - return self.get_variable().get_parent(ast) - if self.get_expression().get_parent(ast) is not None: - return self.get_expression().get_parent(ast) - return None - - def equals(self, other): - """ - The equals operation. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + if self.get_variable(): + children.append(self.get_variable()) + + if self.get_expression(): + children.append(self.get_expression()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTAssignment): return False + return (self.get_variable().equals(other.get_variable()) and self.is_compound_quotient == other.is_compound_quotient and self.is_compound_product == other.is_compound_product @@ -160,26 +153,6 @@ def equals(self, other): and self.is_direct_assignment == other.is_direct_assignment and self.get_expression().equals(other.get_expression())) - def deconstruct_compound_assignment(self): - """ - From lhs and rhs it constructs a new expression which corresponds to direct assignment. - E.g.: a += b*c -> a = a + b*c - :return: the rhs for an equivalent direct assignment. - :rtype: ast_expression - """ - from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - # TODO: get rid of this through polymorphism? - assert not self.is_direct_assignment, "Can only be invoked on a compound assignment." - - operator = self.extract_operator_from_compound_assignment() - lhs_variable = self.get_lhs_variable_as_expression() - rhs_in_brackets = self.get_bracketed_rhs_expression() - result = self.construct_equivalent_direct_assignment_rhs(operator, lhs_variable, rhs_in_brackets) - # create symbols for the new Expression: - visitor = ASTSymbolTableVisitor() - result.accept(visitor) - return result - def get_lhs_variable_as_expression(self): from pynestml.meta_model.ast_node_factory import ASTNodeFactory # TODO: maybe calculate new source positions exactly? diff --git a/pynestml/meta_model/ast_bit_operator.py b/pynestml/meta_model/ast_bit_operator.py index b9617df8e..fe5af3790 100644 --- a/pynestml/meta_model/ast_bit_operator.py +++ b/pynestml/meta_model/ast_bit_operator.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -82,23 +84,16 @@ def clone(self): return dup - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - return None + return [] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTBitOperator): return False diff --git a/pynestml/meta_model/ast_block.py b/pynestml/meta_model/ast_block.py index 61175bf9c..436378e43 100644 --- a/pynestml/meta_model/ast_block.py +++ b/pynestml/meta_model/ast_block.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -95,28 +97,16 @@ def delete_stmt(self, stmt): """ self.stmts.remove(stmt) - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for stmt in self.get_stmts(): - if stmt is ast: - return self - if stmt.get_parent(ast) is not None: - return stmt.get_parent(ast) - return None - - def equals(self, other): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return self.get_stmts() + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTBlock): return False diff --git a/pynestml/meta_model/ast_block_with_variables.py b/pynestml/meta_model/ast_block_with_variables.py index 480535bb8..db2a48084 100644 --- a/pynestml/meta_model/ast_block_with_variables.py +++ b/pynestml/meta_model/ast_block_with_variables.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -112,28 +114,16 @@ def clear(self): del self.declarations self.declarations = list() - def get_parent(self, ast=None): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - for stmt in self.get_declarations(): - if stmt is ast: - return self - if stmt.get_parent(ast) is not None: - return stmt.get_parent(ast) - return None + return self.get_declarations() - def equals(self, other=None): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTBlockWithVariables): return False diff --git a/pynestml/meta_model/ast_comparison_operator.py b/pynestml/meta_model/ast_comparison_operator.py index 22970d8a3..2c94d59e2 100644 --- a/pynestml/meta_model/ast_comparison_operator.py +++ b/pynestml/meta_model/ast_comparison_operator.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -94,26 +96,20 @@ def clone(self): return dup - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - return None + return [] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTComparisonOperator): return False + return (self.is_lt == other.is_lt and self.is_le == other.is_le and self.is_eq == other.is_eq and self.is_ne == other.is_ne and self.is_ne2 == other.is_ne2 and self.is_ge == other.is_ge and self.is_gt == other.is_gt) diff --git a/pynestml/meta_model/ast_compound_stmt.py b/pynestml/meta_model/ast_compound_stmt.py index e344f6f0c..96f497f9c 100644 --- a/pynestml/meta_model/ast_compound_stmt.py +++ b/pynestml/meta_model/ast_compound_stmt.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_for_stmt import ASTForStmt from pynestml.meta_model.ast_if_stmt import ASTIfStmt from pynestml.meta_model.ast_node import ASTNode @@ -141,48 +143,39 @@ def get_for_stmt(self): """ return self.for_stmt - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_if_stmt(): - if self.get_if_stmt() is ast: - return self - if self.get_if_stmt().get_parent(ast) is not None: - return self.get_if_stmt().get_parent(ast) + return [self.get_if_stmt()] + if self.is_while_stmt(): - if self.get_while_stmt() is ast: - return self - if self.get_while_stmt().get_parent(ast) is not None: - return self.get_while_stmt().get_parent(ast) + return [self.get_while_stmt()] + if self.is_for_stmt(): - if self.is_for_stmt() is ast: - return self - if self.get_for_stmt().get_parent(ast) is not None: - return self.get_for_stmt().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return [self.get_for_stmt()] + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTCompoundStmt): return False + if self.get_for_stmt() is not None and other.get_for_stmt() is not None and \ not self.get_for_stmt().equals(other.get_for_stmt()): return False + if self.get_while_stmt() is not None and other.get_while_stmt() is not None and \ not self.get_while_stmt().equals(other.get_while_stmt()): return False + if self.get_if_stmt() is not None and other.get_if_stmt() is not None and \ not self.get_if_stmt().equals(other.get_if_stmt()): return False + return True diff --git a/pynestml/meta_model/ast_data_type.py b/pynestml/meta_model/ast_data_type.py index 92515507a..07eec76c3 100644 --- a/pynestml/meta_model/ast_data_type.py +++ b/pynestml/meta_model/ast_data_type.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional +from typing import List, Optional from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_unit_type import ASTUnitType @@ -141,38 +141,33 @@ def set_type_symbol(self, type_symbol): '(PyNestML.AST.DataType) No or wrong type of type symbol provided (%s)!' % (type(type_symbol)) self.type_symbol = type_symbol - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_unit_type(): - if self.get_unit_type() is ast: - return self - if self.get_unit_type().get_parent(ast) is not None: - return self.get_unit_type().get_parent(ast) - return None + return [self.get_unit_type()] - def equals(self, other): - """ - The equals method. - :param other: a different object - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTDataType): return False + if not (self.is_integer == other.is_integer and self.is_real == other.is_real and self.is_string == other.is_string and self.is_boolean == other.is_boolean and self.is_void == other.is_void): return False + # only one of them uses a unit, thus false if self.is_unit_type() + other.is_unit_type() == 1: return False + if self.is_unit_type() and other.is_unit_type() and not self.get_unit_type().equals(other.get_unit_type()): return False + return True diff --git a/pynestml/meta_model/ast_declaration.py b/pynestml/meta_model/ast_declaration.py index 98692eb0b..f7d7aca03 100644 --- a/pynestml/meta_model/ast_declaration.py +++ b/pynestml/meta_model/ast_declaration.py @@ -19,12 +19,13 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional, List +from typing import Optional, List, Union from pynestml.meta_model.ast_data_type import ASTDataType from pynestml.meta_model.ast_expression import ASTExpression from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression from pynestml.meta_model.ast_variable import ASTVariable @@ -58,7 +59,7 @@ class ASTDeclaration(ASTNode): invariant = None """ - def __init__(self, is_recordable: bool = False, is_inline_expression: bool = False, _variables: Optional[List[ASTVariable]] = None, data_type: Optional[ASTDataType] = None, size_parameter: Optional[str] = None, + def __init__(self, is_recordable: bool = False, is_inline_expression: bool = False, _variables: Optional[List[ASTVariable]] = None, data_type: Optional[ASTDataType] = None, size_parameter: Optional[Union[ASTSimpleExpression, ASTExpression]] = None, expression: Optional[ASTExpression] = None, invariant: Optional[ASTExpression] = None, decorators=None, *args, **kwargs): """ Standard constructor. @@ -148,37 +149,31 @@ def get_data_type(self): """ return self.data_type - def has_size_parameter(self): + def has_size_parameter(self) -> bool: """ Returns whether the declaration has a size parameter or not. :return: True if has size parameter, else False. - :rtype: bool """ return self.size_parameter is not None - def get_size_parameter(self): + def get_size_parameter(self) -> Optional[Union[ASTSimpleExpression, ASTExpression]]: """ Returns the size parameter. :return: the size parameter. - :rtype: str """ return self.size_parameter - def set_size_parameter(self, _parameter): + def set_size_parameter(self, _parameter: Optional[Union[ASTSimpleExpression, ASTExpression]]): """ Updates the current size parameter to a new value. :param _parameter: the size parameter - :type _parameter: str """ - assert (_parameter is not None and isinstance(_parameter, str)), \ - '(PyNestML.AST.Declaration) No or wrong type of size parameter provided (%s)!' % type(_parameter) self.size_parameter = _parameter - def has_expression(self): + def has_expression(self) -> bool: """ Returns whether the declaration has a right-hand side rhs or not. :return: True if right-hand side rhs declared, else False. - :rtype: bool """ return self.expression is not None @@ -210,60 +205,55 @@ def get_invariant(self): """ return self.invariant - def get_parent(self, ast): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for var in self.get_variables(): - if var is ast: - return self - if var.get_parent(ast) is not None: - return var.get_parent(ast) - if self.get_data_type() is ast: - return self - if self.get_data_type().get_parent(ast) is not None: - return self.get_data_type().get_parent(ast) + children = [] + children.extend(self.get_variables()) + if self.get_data_type(): + children.append(self.get_data_type()) + if self.has_expression(): - if self.get_expression() is ast: - return self - if self.get_expression().get_parent(ast) is not None: - return self.get_expression().get_parent(ast) + children.append(self.get_expression()) + if self.has_invariant(): - if self.get_invariant() is ast: - return self - if self.get_invariant().get_parent(ast) is not None: - return self.get_invariant().get_parent(ast) - return None + children.append(self.get_invariant()) - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + if self.has_size_parameter(): + children.append(self.get_size_parameter()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTDeclaration): return False + if not (self.is_inline_expression == other.is_inline_expression and self.is_recordable == other.is_recordable): return False + if self.get_size_parameter() != other.get_size_parameter(): return False + if len(self.get_variables()) != len(other.get_variables()): return False + my_vars = self.get_variables() your_vars = other.get_variables() for i in range(0, len(my_vars)): # caution, here the order is also checked if not my_vars[i].equals(your_vars[i]): return False + if self.has_invariant() + other.has_invariant() == 1: return False + if self.has_invariant() and other.has_invariant() and not self.get_invariant().equals(other.get_invariant()): return False + return self.get_data_type().equals(other.get_data_type()) and self.get_expression().equals( other.get_expression()) diff --git a/pynestml/meta_model/ast_elif_clause.py b/pynestml/meta_model/ast_elif_clause.py index d88595307..1165cd3b7 100644 --- a/pynestml/meta_model/ast_elif_clause.py +++ b/pynestml/meta_model/ast_elif_clause.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -88,32 +90,25 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_condition() is ast: - return self - if self.get_condition().get_parent(ast) is not None: - return self.get_condition().get_parent(ast) - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + if self.get_condition(): + children.append(self.get_condition()) + + if self.get_block(): + children.append(self.get_block()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTElifClause): return False + return self.get_condition().equals(other.get_condition()) and self.get_block().equals(other.get_block()) diff --git a/pynestml/meta_model/ast_else_clause.py b/pynestml/meta_model/ast_else_clause.py index f548f0eb1..55bd8fdea 100644 --- a/pynestml/meta_model/ast_else_clause.py +++ b/pynestml/meta_model/ast_else_clause.py @@ -18,6 +18,9 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -71,28 +74,21 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None + if self.get_block(): + return [self.get_block()] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTElseClause): return False + return self.get_block().equals(other.get_block()) diff --git a/pynestml/meta_model/ast_equations_block.py b/pynestml/meta_model/ast_equations_block.py index a7c3057ed..18d576138 100644 --- a/pynestml/meta_model/ast_equations_block.py +++ b/pynestml/meta_model/ast_equations_block.py @@ -19,12 +19,12 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Any, Sequence -from pynestml.meta_model.ast_kernel import ASTKernel +from typing import Any, List, Sequence +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_ode_equation import ASTOdeEquation -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression class ASTEquationsBlock(ASTNode): @@ -80,21 +80,6 @@ def get_declarations(self): """ return self.declarations - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for decl in self.get_declarations(): - if decl is ast: - return self - if decl.get_parent(ast) is not None: - return decl.get_parent(ast) - return None - def get_ode_equations(self) -> Sequence[ASTOdeEquation]: """ Returns a list of all ode equations in this block. @@ -135,11 +120,16 @@ def clear(self): del self.declarations self.declarations = list() - def equals(self, other: Any) -> bool: + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - The equals method. - :param other: a different object. - :return: True if equal, otherwise False. + return self.get_declarations() + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTEquationsBlock): return False diff --git a/pynestml/meta_model/ast_expression.py b/pynestml/meta_model/ast_expression.py index 85351fc38..c476bb58f 100644 --- a/pynestml/meta_model/ast_expression.py +++ b/pynestml/meta_model/ast_expression.py @@ -18,14 +18,17 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + from __future__ import annotations -from typing import Union + +from typing import List, Union from pynestml.meta_model.ast_expression_node import ASTExpressionNode from pynestml.meta_model.ast_logical_operator import ASTLogicalOperator from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator from pynestml.meta_model.ast_bit_operator import ASTBitOperator from pynestml.meta_model.ast_comparison_operator import ASTComparisonOperator +from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_unary_operator import ASTUnaryOperator @@ -313,59 +316,48 @@ def get_function_calls(self): ret.extend(self.get_if_not().get_function_calls()) return ret - def get_parent(self, ast=None): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_expression(): - if self.get_expression() is ast: - return self - if self.get_expression().get_parent(ast) is not None: - return self.get_expression().get_parent(ast) + return [self.get_expression()] + if self.is_unary_operator(): - if self.get_unary_operator() is ast: - return self - if self.get_unary_operator().get_parent(ast) is not None: - return self.get_unary_operator().get_parent(ast) + return [self.get_unary_operator(), self.get_rhs()] + if self.is_compound_expression(): - if self.get_lhs() is ast: - return self - if self.get_lhs().get_parent(ast) is not None: - return self.get_lhs().get_parent(ast) - if self.get_binary_operator() is ast: - return self - if self.get_binary_operator().get_parent(ast) is not None: - return self.get_binary_operator().get_parent(ast) - if self.get_rhs() is ast: - return self - if self.get_rhs().get_parent(ast) is not None: - return self.get_rhs().get_parent(ast) + children = [] + if self.get_lhs(): + children.append(self.get_lhs()) + + if self.get_binary_operator(): + children.append(self.get_binary_operator()) + + if self.get_rhs(): + children.append(self.get_rhs()) + + return children + if self.is_ternary_operator(): - if self.get_condition() is ast: - return self - if self.get_condition().get_parent(ast) is not None: - return self.get_condition().get_parent(ast) - if self.get_if_true() is ast: - return self - if self.get_if_true().get_parent(ast) is not None: - return self.get_if_true().get_parent(ast) - if self.get_if_not() is ast: - return self - if self.get_if_not().get_parent(ast) is not None: - return self.get_if_not().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + if self.get_condition(): + children.append(self.get_condition()) + + if self.get_if_true(): + children.append(self.get_if_true()) + + if self.get_if_not(): + children.append(self.get_if_not()) + + return children + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTExpression): return False diff --git a/pynestml/meta_model/ast_expression_node.py b/pynestml/meta_model/ast_expression_node.py index 33e96e0f4..397d1b38d 100644 --- a/pynestml/meta_model/ast_expression_node.py +++ b/pynestml/meta_model/ast_expression_node.py @@ -18,6 +18,9 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + +from typing import List + from abc import ABCMeta from copy import copy @@ -47,8 +50,15 @@ def type(self): def type(self, _value): self.__type = _value - def get_parent(self, ast): - pass + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + assert False - def equals(self, other): - pass + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. + """ + assert False diff --git a/pynestml/meta_model/ast_for_stmt.py b/pynestml/meta_model/ast_for_stmt.py index 97cbe96d0..399071d01 100644 --- a/pynestml/meta_model/ast_for_stmt.py +++ b/pynestml/meta_model/ast_for_stmt.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -137,35 +139,26 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - if self.get_start_from() is ast: - return self - if self.get_start_from().get_parent(ast) is not None: - return self.get_start_from().get_parent(ast) - if self.get_end_at() is ast: - return self - if self.get_end_at().get_parent(ast) is not None: - return self.get_end_at().get_parent(ast) - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + children = [] + if self.get_start_from(): + children.append(self.get_start_from()) + + if self.get_end_at(): + children.append(self.get_end_at()) + + if self.get_block(): + children.append(self.get_block()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTForStmt): return False diff --git a/pynestml/meta_model/ast_function.py b/pynestml/meta_model/ast_function.py index def470e00..842c086ad 100644 --- a/pynestml/meta_model/ast_function.py +++ b/pynestml/meta_model/ast_function.py @@ -164,52 +164,46 @@ def set_type_symbol(self, type_symbol): """ self.type_symbol = type_symbol - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for param in self.get_parameters(): - if param is ast: - return self - if param.get_parent(ast) is not None: - return param.get_parent(ast) + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + children = [] + children.extend(self.get_parameters()) + if self.has_return_type(): - if self.get_return_type() is ast: - return self - if self.get_return_type().get_parent(ast) is not None: - return self.get_return_type().get_parent(ast) - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children.append(self.get_return_type()) + + if self.get_block(): + children.append(self.get_block()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTFunction): return False + if self.get_name() != other.get_name(): return False + if len(self.get_parameters()) != len(other.get_parameters()): return False + my_parameters = self.get_parameters() your_parameters = other.get_parameters() for i in range(0, len(my_parameters)): if not my_parameters[i].equals(your_parameters[i]): return False + if self.has_return_type() + other.has_return_type() == 1: return False + if (self.has_return_type() and other.has_return_type() and not self.get_return_type().equals(other.get_return_type())): return False + return self.get_block().equals(other.get_block()) diff --git a/pynestml/meta_model/ast_function_call.py b/pynestml/meta_model/ast_function_call.py index c9966e2dc..a078e188a 100644 --- a/pynestml/meta_model/ast_function_call.py +++ b/pynestml/meta_model/ast_function_call.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -97,28 +99,16 @@ def get_args(self): """ return self.args - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - for param in self.get_args(): - if param is ast: - return self - if param.get_parent(ast) is not None: - return param.get_parent(ast) - return None + return self.get_args() - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTFunctionCall): return False diff --git a/pynestml/meta_model/ast_if_clause.py b/pynestml/meta_model/ast_if_clause.py index 2295d54d2..36b8cdc9b 100644 --- a/pynestml/meta_model/ast_if_clause.py +++ b/pynestml/meta_model/ast_if_clause.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -88,31 +90,24 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: ASTNode - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: Optional[ASTNode] + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_condition() is ast: - return self - if self.get_condition().get_parent(ast) is not None: - return self.get_condition().get_parent(ast) - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equals, otherwise False. - :rtype: bool + children = [] + + if self.get_condition(): + children.append(self.get_condition()) + + if self.get_block(): + children.append(self.get_block()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTIfClause): return False diff --git a/pynestml/meta_model/ast_if_stmt.py b/pynestml/meta_model/ast_if_stmt.py index 6ae19d494..5d2cfcfe3 100644 --- a/pynestml/meta_model/ast_if_stmt.py +++ b/pynestml/meta_model/ast_if_stmt.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -127,37 +129,25 @@ def get_else_clause(self): """ return self.else_clause - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: ASTNode - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: Optional[ASTNode] - """ - if self.get_if_clause() is ast: - return self - if self.get_if_clause().get_parent(ast) is not None: - return self.get_if_clause().get_parent(ast) - for elifClause in self.get_elif_clauses(): - if elifClause is ast: - return self - if elifClause.get_parent(ast) is not None: - return elifClause.get_parent(ast) + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + children = [] + if self.get_if_clause(): + children.append(self.get_if_clause()) + + children.extend(self.get_elif_clauses()) + if self.has_else_clause(): - if self.get_else_clause() is ast: - return self - if self.get_else_clause().get_parent(ast) is not None: - return self.get_else_clause().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equals, otherwise False. - :rtype: bool + children.append(self.get_else_clause()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTIfStmt): return False diff --git a/pynestml/meta_model/ast_inline_expression.py b/pynestml/meta_model/ast_inline_expression.py index 88a56f162..b8af0f928 100644 --- a/pynestml/meta_model/ast_inline_expression.py +++ b/pynestml/meta_model/ast_inline_expression.py @@ -19,8 +19,10 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.meta_model.ast_node import ASTNode +from typing import List + from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator +from pynestml.meta_model.ast_node import ASTNode class ASTInlineExpression(ASTNode): @@ -128,31 +130,23 @@ def get_expression(self): """ return self.expression - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - if self.get_data_type() is ast: - return self - if self.get_data_type().get_parent(ast) is not None: - return self.get_data_type().get_parent(ast) - if self.get_expression() is ast: - return self - if self.get_expression().get_parent(ast) is not None: - return self.get_expression().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + children = [] + if self.get_data_type(): + children.append(self.get_data_type()) + + if self.get_expression(): + children.append(self.get_expression()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTInlineExpression): return False diff --git a/pynestml/meta_model/ast_input_block.py b/pynestml/meta_model/ast_input_block.py index a5ebb615b..e5caa42b6 100644 --- a/pynestml/meta_model/ast_input_block.py +++ b/pynestml/meta_model/ast_input_block.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_input_port import ASTInputPort from pynestml.meta_model.ast_node import ASTNode @@ -90,36 +92,27 @@ def get_input_ports(self): """ return self.input_definitions - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - for port in self.get_input_ports(): - if port is ast: - return self - if port.get_parent(ast) is not None: - return port.get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return self.get_input_ports() + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTInputBlock): return False + if len(self.get_input_ports()) != len(other.get_input_ports()): return False + my_input_ports = self.get_input_ports() your_input_ports = other.get_input_ports() for i in range(0, len(my_input_ports)): if not my_input_ports[i].equals(your_input_ports[i]): return False + return True diff --git a/pynestml/meta_model/ast_input_port.py b/pynestml/meta_model/ast_input_port.py index 964f3504c..45bc87dbb 100644 --- a/pynestml/meta_model/ast_input_port.py +++ b/pynestml/meta_model/ast_input_port.py @@ -21,11 +21,13 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from pynestml.meta_model.ast_data_type import ASTDataType +from pynestml.meta_model.ast_expression import ASTExpression from pynestml.meta_model.ast_input_qualifier import ASTInputQualifier from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression from pynestml.utils.port_signal_type import PortSignalType @@ -58,7 +60,7 @@ class ASTInputPort(ASTNode): def __init__(self, name: str, signal_type: PortSignalType, - size_parameter: Optional[str] = None, + size_parameter: Optional[Union[ASTSimpleExpression, ASTExpression]] = None, data_type: Optional[ASTDataType] = None, input_qualifiers: Optional[List[ASTInputQualifier]] = None, *args, **kwargs): @@ -120,7 +122,7 @@ def has_size_parameter(self) -> bool: """ return self.size_parameter is not None - def get_size_parameter(self) -> str: + def get_size_parameter(self) -> Optional[Union[ASTSimpleExpression, ASTExpression]]: r""" Returns the size parameter. :return: the size parameter. @@ -195,48 +197,53 @@ def get_datatype(self) -> ASTDataType: """ return self.data_type - def get_parent(self, ast: ASTNode) -> Optional[ASTNode]: + def get_children(self) -> List[ASTNode]: r""" - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :return: AST if this or one of the child nodes contains the handed over element. + Returns the children of this node, if any. + :return: List of children of this node. """ + children = [] if self.has_datatype(): - if self.get_datatype() is ast: - return self - if self.get_datatype().get_parent(ast) is not None: - return self.get_datatype().get_parent(ast) + children.append(self.get_datatype()) + for qual in self.get_input_qualifiers(): - if qual is ast: - return self - if qual.get_parent(ast) is not None: - return qual.get_parent(ast) - return None + children.append(qual) + + if self.get_size_parameter(): + children.append(self.get_size_parameter()) + + return children - def equals(self, other: Any) -> bool: + def equals(self, other: ASTNode) -> bool: r""" - The equals method. - :param other: a different object. - :return: True if equal,otherwise False. + The equality method. """ if not isinstance(other, ASTInputPort): return False + if self.get_name() != other.get_name(): return False + if self.has_size_parameter() + other.has_size_parameter() == 1: return False + if (self.has_size_parameter() and other.has_size_parameter() and self.get_input_qualifiers() != other.get_size_parameter()): return False + if self.has_datatype() + other.has_datatype() == 1: return False + if self.has_datatype() and other.has_datatype() and not self.get_datatype().equals(other.get_datatype()): return False + if len(self.get_input_qualifiers()) != len(other.get_input_qualifiers()): return False + my_input_qualifiers = self.get_input_qualifiers() your_input_qualifiers = other.get_input_qualifiers() for i in range(0, len(my_input_qualifiers)): if not my_input_qualifiers[i].equals(your_input_qualifiers[i]): return False + return self.is_spike() == other.is_spike() and self.is_continuous() == other.is_continuous() diff --git a/pynestml/meta_model/ast_input_qualifier.py b/pynestml/meta_model/ast_input_qualifier.py index 8678ae660..6c34c33ec 100644 --- a/pynestml/meta_model/ast_input_qualifier.py +++ b/pynestml/meta_model/ast_input_qualifier.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List from pynestml.meta_model.ast_node import ASTNode @@ -72,24 +73,18 @@ def clone(self): return dup - def get_parent(self, ast): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: ASTNode - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: Optional[ASTNode] - """ - return None + return [] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTInputQualifier): return False + return self.is_excitatory == other.is_excitatory and self.is_inhibitory == other.is_inhibitory diff --git a/pynestml/meta_model/ast_kernel.py b/pynestml/meta_model/ast_kernel.py index 49564bd6a..e152e118f 100644 --- a/pynestml/meta_model/ast_kernel.py +++ b/pynestml/meta_model/ast_kernel.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -93,37 +95,19 @@ def get_expressions(self): """ return self.expressions - def get_parent(self, ast): - """ - Indicates whether this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: ASTNode - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: ASTNode or None - """ - for var in self.get_variables(): - if var is ast: - return self - - if var.get_parent(ast) is not None: - return var.get_parent(ast) - - for expr in self.get_expressions(): - if expr is ast: - return self - - if expr.get_parent(ast) is not None: - return expr.get_parent(ast) - - return None - - def equals(self, other): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + children.extend(self.get_variables()) + children.extend(self.get_expressions()) + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTKernel): return False diff --git a/pynestml/meta_model/ast_logical_operator.py b/pynestml/meta_model/ast_logical_operator.py index d98eb2a83..e3f3a314f 100644 --- a/pynestml/meta_model/ast_logical_operator.py +++ b/pynestml/meta_model/ast_logical_operator.py @@ -18,6 +18,9 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -67,24 +70,18 @@ def clone(self): return dup - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - return None + return [] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTLogicalOperator): return False + return self.is_logical_and == other.is_logical_and and self.is_logical_or == other.is_logical_or diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py index 55d08937c..834e56897 100644 --- a/pynestml/meta_model/ast_model.py +++ b/pynestml/meta_model/ast_model.py @@ -471,9 +471,11 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) - self.get_internals_blocks()[0].get_declarations().insert(index, declaration) declaration.update_scope(self.get_internals_blocks()[0].get_scope()) from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.INTERNALS) declaration.accept(symtable_vistor) + self.get_internals_blocks()[0].accept(ASTParentVisitor()) symtable_vistor.block_type_stack.pop() def add_to_state_block(self, declaration: ASTDeclaration) -> None: @@ -488,10 +490,11 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None: self.get_state_blocks()[0].get_declarations().append(declaration) declaration.update_scope(self.get_state_blocks()[0].get_scope()) from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.STATE) declaration.accept(symtable_vistor) + self.get_state_blocks()[0].accept(ASTParentVisitor()) symtable_vistor.block_type_stack.pop() from pynestml.symbols.symbol import SymbolKind assert declaration.get_variables()[0].get_scope().resolve_to_symbol( @@ -514,16 +517,6 @@ def print_comment(self, prefix: str = "") -> str: return ret - def equals(self, other: ASTNode) -> bool: - """ - The equals method. - :param other: a different object. - :return: True if equal, otherwise False. - """ - if not isinstance(other, ASTModel): - return False - return self.get_name() == other.get_name() and self.get_body().equals(other.get_body()) - def get_initial_value(self, variable_name: str): assert type(variable_name) is str @@ -652,19 +645,6 @@ def has_state_vectors(self) -> bool: return False - def get_parent(self, ast) -> Optional[ASTNode]: - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - """ - if self.get_body() is ast: - return self - if self.get_body().get_parent(ast) is not None: - return self.get_body().get_parent(ast) - return None - def set_default_delay(self, var, expr, dtype): self._default_delay_variable = var self._default_delay_expression = expr @@ -712,3 +692,18 @@ def get_input_buffers(self): or symbol.block_type == BlockType.INPUT_BUFFER_CURRENT): ret.append(symbol) return ret + + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + return [self.get_body()] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. + """ + if not isinstance(other, ASTModel): + return False + return self.get_name() == other.get_name() and self.get_body().equals(other.get_body()) diff --git a/pynestml/meta_model/ast_model_body.py b/pynestml/meta_model/ast_model_body.py index f283ddbc2..6e32561ce 100644 --- a/pynestml/meta_model/ast_model_body.py +++ b/pynestml/meta_model/ast_model_body.py @@ -225,21 +225,6 @@ def get_output_blocks(self) -> List[ASTOutputBlock]: return ret - def get_parent(self, ast=None): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for stmt in self.get_body_elements(): - if stmt is ast: - return self - if stmt.get_parent(ast) is not None: - return stmt.get_parent(ast) - return None - def get_spike_input_ports(self) -> List[ASTInputPort]: """ Returns a list of all spike input ports defined in the model. @@ -253,13 +238,16 @@ def get_spike_input_ports(self) -> List[ASTInputPort]: ret.append(port) return ret - def equals(self, other): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return self.get_body_elements() + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTModelBody): return False diff --git a/pynestml/meta_model/ast_namespace_decorator.py b/pynestml/meta_model/ast_namespace_decorator.py index fd0a0dd46..33cc63b54 100644 --- a/pynestml/meta_model/ast_namespace_decorator.py +++ b/pynestml/meta_model/ast_namespace_decorator.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional +from typing import List, Optional from pynestml.meta_model.ast_node import ASTNode @@ -66,23 +66,16 @@ def get_name(self) -> str: """ return self.name - def get_parent(self, ast: ASTNode) -> Optional[ASTNode]: + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :return: AST if this or one of the child nodes contains the handed over element. - """ - if self.get_name() is ast: - return self - elif self.get_namespace() is ast: - return self - return None + return [self.get_name(), self.get_namespace()] def equals(self, other: ASTNode) -> bool: - """ - The equals operation. - :param other: a different object. - :return: True if equal, otherwise False. + r""" + The equality method. """ if not isinstance(other, ASTNamespaceDecorator): return False diff --git a/pynestml/meta_model/ast_nestml_compilation_unit.py b/pynestml/meta_model/ast_nestml_compilation_unit.py index 0b10c7724..6e36f1fe3 100644 --- a/pynestml/meta_model/ast_nestml_compilation_unit.py +++ b/pynestml/meta_model/ast_nestml_compilation_unit.py @@ -102,26 +102,16 @@ def get_model_by_name(self, name: str) -> Optional[ASTModel]: return None - def get_parent(self, ast): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - for model in self.get_model_list(): - if model is ast: - return self - if model.get_parent(ast) is not None: - return model.get_parent(ast) - return None + return self.get_model_list() - def equals(self, other) -> bool: - """ - The equals method. - :param other: a different object - :return: True if equal, otherwise False. + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTNestMLCompilationUnit): return False diff --git a/pynestml/meta_model/ast_node.py b/pynestml/meta_model/ast_node.py index fd08c7dc7..64234fbce 100644 --- a/pynestml/meta_model/ast_node.py +++ b/pynestml/meta_model/ast_node.py @@ -19,12 +19,13 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from __future__ import annotations + from typing import Optional, List from abc import ABCMeta, abstractmethod from pynestml.symbol_table.scope import Scope - from pynestml.utils.ast_source_location import ASTSourceLocation @@ -74,25 +75,31 @@ def clone(self): pass @abstractmethod - def equals(self, other): + def equals(self, other: ASTNode) -> bool: """ The equals operation. - :param other: a different object. - :type other: object + :param other: a different AST node. :return: True if equal, otherwise False. - :rtype: bool """ pass - # todo: we can do this with a visitor instead of hard coding grammar traversals all over the place - @abstractmethod - def get_parent(self, ast): + def get_parent(self) -> Optional[ASTNode]: + """ + Get the parent of this node. + :return: The parent node """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + assert "parent_" in dir(self), "No parent known, please ensure ASTParentVisitor has been run on the AST" + + if self.parent_: + assert self in self.parent_.get_children(), "Doubly linked tree is inconsistent: please ensure ASTParentVisitor has been run on the AST" + + return self.parent_ + + @abstractmethod + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ pass diff --git a/pynestml/meta_model/ast_node_factory.py b/pynestml/meta_model/ast_node_factory.py index 4c4cd6792..da3986be9 100644 --- a/pynestml/meta_model/ast_node_factory.py +++ b/pynestml/meta_model/ast_node_factory.py @@ -151,9 +151,9 @@ def create_ast_declaration(cls, is_inline_expression: bool=False, variables=None, # type: list data_type=None, # type: ASTDataType - size_parameter=None, # type: str - expression=None, # type: Union(ASTSimpleExpression,ASTExpression) - invariant=None, # type: Union(ASTSimpleExpression,ASTExpression) + size_parameter=None, # type: Optional[Union[ASTSimpleExpression, ASTExpression]] + expression=None, # type: Optional[Union[ASTSimpleExpression, ASTExpression]] + invariant=None, # type: Optional[Union[ASTSimpleExpression, ASTExpression]] source_position=None, # type: ASTSourceLocation decorators=None, # type: list ) -> ASTDeclaration: diff --git a/pynestml/meta_model/ast_ode_equation.py b/pynestml/meta_model/ast_ode_equation.py index 20151a0e7..95567b367 100644 --- a/pynestml/meta_model/ast_ode_equation.py +++ b/pynestml/meta_model/ast_ode_equation.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_expression import ASTExpression @@ -107,31 +108,23 @@ def get_rhs(self): """ return self.rhs - def get_parent(self, ast=None): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - if self.get_lhs() is ast: - return self - if self.get_lhs().get_parent(ast) is not None: - return self.get_lhs().get_parent(ast) - if self.get_rhs() is ast: - return self - if self.get_rhs().get_parent(ast) is not None: - return self.get_rhs().get_parent(ast) - return None - - def equals(self, other=None): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + if self.get_lhs(): + children.append(self.get_lhs()) + + if self.get_rhs(): + children.append(self.get_rhs()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTOdeEquation): return False diff --git a/pynestml/meta_model/ast_on_condition_block.py b/pynestml/meta_model/ast_on_condition_block.py index dbe0c4346..d8e1ac4cd 100644 --- a/pynestml/meta_model/ast_on_condition_block.py +++ b/pynestml/meta_model/ast_on_condition_block.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, Optional, Mapping +from typing import Any, List, Optional, Mapping from pynestml.meta_model.ast_block import ASTBlock from pynestml.meta_model.ast_expression import ASTExpression @@ -82,25 +82,23 @@ def get_cond_expr(self) -> str: """ return self.cond_expr - def get_parent(self, ast: ASTNode) -> Optional[ASTNode]: + def get_children(self) -> List[ASTNode]: r""" - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :return: AST if this or one of the child nodes contains the handed over element. + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_block() is ast: - return self + children = [] + if self.cond_expr: + children.append(self.cond_expr) - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) + if self.get_block(): + children.append(self.get_block()) - return None + return children - def equals(self, other: Any) -> bool: + def equals(self, other: ASTNode) -> bool: r""" - The equals method. - :param other: a different object. - :return: True if equal, otherwise False. + The equality method. """ if not isinstance(other, ASTOnConditionBlock): return False diff --git a/pynestml/meta_model/ast_on_receive_block.py b/pynestml/meta_model/ast_on_receive_block.py index df46df279..d7ca37812 100644 --- a/pynestml/meta_model/ast_on_receive_block.py +++ b/pynestml/meta_model/ast_on_receive_block.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, Optional, Mapping +from typing import Any, List, Optional, Mapping from pynestml.meta_model.ast_block import ASTBlock from pynestml.meta_model.ast_node import ASTNode @@ -87,26 +87,18 @@ def get_port_name(self) -> str: """ return self.port_name - def get_parent(self, ast: ASTNode) -> Optional[ASTNode]: + def get_children(self) -> List[ASTNode]: r""" - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :return: AST if this or one of the child nodes contains the handed over element. + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_block() is ast: - return self + return [self.get_block()] - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - - return None - - def equals(self, other: Any) -> bool: + def equals(self, other: ASTNode) -> bool: r""" - The equals method. - :param other: a different object. - :return: True if equal, otherwise False. + The equality method. """ if not isinstance(other, ASTOnReceiveBlock): return False + return self.get_block().equals(other.get_block()) and self.port_name == other.port_name diff --git a/pynestml/meta_model/ast_output_block.py b/pynestml/meta_model/ast_output_block.py index 0ce8d193c..66a61f71d 100644 --- a/pynestml/meta_model/ast_output_block.py +++ b/pynestml/meta_model/ast_output_block.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode from pynestml.utils.port_signal_type import PortSignalType @@ -82,23 +84,18 @@ def is_continuous(self) -> bool: """ return self.type is PortSignalType.CONTINUOUS - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - return None + return [] - def equals(self, other) -> bool: - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equals, otherwise False. + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTOutputBlock): return False + return self.is_spike() == other.is_spike() and self.is_continuous() == other.is_continuous() diff --git a/pynestml/meta_model/ast_parameter.py b/pynestml/meta_model/ast_parameter.py index 1ac2c88aa..6e30f4af5 100644 --- a/pynestml/meta_model/ast_parameter.py +++ b/pynestml/meta_model/ast_parameter.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_data_type import ASTDataType from pynestml.meta_model.ast_node import ASTNode @@ -81,28 +83,18 @@ def get_data_type(self): """ return self.data_type - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_data_type() is ast: - return self - if self.get_data_type().get_parent(ast) is not None: - return self.get_data_type().get_parent(ast) - return None + return [self.get_data_type()] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTParameter): return False + return self.get_name() == other.get_name() and self.get_data_type().equals(other.get_data_type()) diff --git a/pynestml/meta_model/ast_return_stmt.py b/pynestml/meta_model/ast_return_stmt.py index 68e2d63a4..43719747a 100644 --- a/pynestml/meta_model/ast_return_stmt.py +++ b/pynestml/meta_model/ast_return_stmt.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -83,29 +85,21 @@ def get_expression(self): """ return self.expression - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.has_expression(): - if self.get_expression() is ast: - return self - if self.get_expression().get_parent(ast) is not None: - return self.get_expression().get_parent(ast) - return None + return [self.get_expression()] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTReturnStmt): return False + return self.get_expression().equals(other.get_expression()) diff --git a/pynestml/meta_model/ast_simple_expression.py b/pynestml/meta_model/ast_simple_expression.py index 444f0df09..8514f76d2 100644 --- a/pynestml/meta_model/ast_simple_expression.py +++ b/pynestml/meta_model/ast_simple_expression.py @@ -19,10 +19,11 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional, Union +from typing import List, Optional, Union from pynestml.meta_model.ast_expression_node import ASTExpressionNode from pynestml.meta_model.ast_function_call import ASTFunctionCall +from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_variable import ASTVariable from pynestml.utils.cloning_helpers import clone_numeric_literal @@ -272,25 +273,18 @@ def get_string(self): """ return self.string - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_function_call(): - if self.get_function_call() is ast: - return self - if self.get_function_call().get_parent(ast) is not None: - return self.get_function_call().get_parent(ast) - if self.variable is not None: - if self.variable is ast: - return self - if self.variable.get_parent(ast) is not None: - return self.variable.get_parent(ast) - return None + return [self.get_function_call()] + + if self.variable: + return [self.variable] + + return [] def set_variable(self, variable): """ @@ -312,33 +306,39 @@ def set_function_call(self, function_call): '(PyNestML.AST.SimpleExpression) No or wrong type of function call provided (%s)!' % type(function_call) self.function_call = function_call - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return:True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTSimpleExpression): return False + if self.is_function_call() + other.is_function_call() == 1: return False + if self.is_function_call() and other.is_function_call() and not self.get_function_call().equals( other.get_function_call()): return False + if self.get_numeric_literal() != other.get_numeric_literal(): return False + if self.is_boolean_false != other.is_boolean_false or self.is_boolean_true != other.is_boolean_true: return False + if self.is_variable() + other.is_variable() == 1: return False + if self.is_variable() and other.is_variable() and not self.get_variable().equals(other.get_variable()): return False + if self.is_inf_literal != other.is_inf_literal: return False + if self.is_string() + other.is_string() == 1: return False + if self.get_string() != other.get_string(): return False + return True diff --git a/pynestml/meta_model/ast_small_stmt.py b/pynestml/meta_model/ast_small_stmt.py index cad9fa1fd..e570084f2 100644 --- a/pynestml/meta_model/ast_small_stmt.py +++ b/pynestml/meta_model/ast_small_stmt.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List from pynestml.meta_model.ast_node import ASTNode @@ -156,43 +157,28 @@ def get_return_stmt(self): """ return self.return_stmt - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_assignment(): - if self.get_assignment() is ast: - return self - if self.get_assignment().get_parent(ast) is not None: - return self.get_assignment().get_parent(ast) + return [self.get_assignment()] + if self.is_function_call(): - if self.get_function_call() is ast: - return self - if self.get_function_call().get_parent(ast) is not None: - return self.get_function_call().get_parent(ast) + return [self.get_function_call()] + if self.is_declaration(): - if self.get_declaration() is ast: - return self - if self.get_declaration().get_parent(ast) is not None: - return self.get_declaration().get_parent(ast) + return [self.get_declaration()] + if self.is_return_stmt(): - if self.get_return_stmt() is ast: - return self - if self.get_return_stmt().get_parent(ast) is not None: - return self.get_return_stmt().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object - :type other: object - :return: True if equals, otherwise False. - :rtype: bool + return [self.get_return_stmt()] + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTSmallStmt): return False diff --git a/pynestml/meta_model/ast_stmt.py b/pynestml/meta_model/ast_stmt.py index 33e279e54..652ad48e8 100644 --- a/pynestml/meta_model/ast_stmt.py +++ b/pynestml/meta_model/ast_stmt.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Optional +from typing import List, Optional from pynestml.meta_model.ast_node import ASTNode @@ -74,27 +74,29 @@ def clone(self): return dup - def get_parent(self, ast: ASTNode=None) -> Optional[ASTNode]: - """ - Returns the parent node of a handed over AST object. - """ - if self.small_stmt is ast: - return self - if self.small_stmt is not None and self.small_stmt.get_parent(ast) is not None: - return self.small_stmt.get_parent(ast) - if self.compound_stmt is ast: - return self - if self.compound_stmt is not None and self.compound_stmt.get_parent(ast) is not None: - return self.compound_stmt.get_parent(ast) - return None - def is_small_stmt(self): return self.small_stmt is not None def is_compound_stmt(self): return self.compound_stmt is not None - def equals(self, other=None): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + if self.small_stmt: + return [self.small_stmt] + + if self.compound_stmt: + return [self.compound_stmt] + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. + """ if not isinstance(other, ASTStmt): return False if self.is_small_stmt() and other.is_small_stmt(): diff --git a/pynestml/meta_model/ast_unary_operator.py b/pynestml/meta_model/ast_unary_operator.py index 3d578294c..2871956cc 100644 --- a/pynestml/meta_model/ast_unary_operator.py +++ b/pynestml/meta_model/ast_unary_operator.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode @@ -73,23 +75,16 @@ def clone(self): return dup - def get_parent(self, ast): - """ - Indicates whether this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - return None + return [] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTUnaryOperator): return False diff --git a/pynestml/meta_model/ast_unit_type.py b/pynestml/meta_model/ast_unit_type.py index 57a6f88d4..346b9d7e4 100644 --- a/pynestml/meta_model/ast_unit_type.py +++ b/pynestml/meta_model/ast_unit_type.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_node import ASTNode from pynestml.utils.cloning_helpers import clone_numeric_literal @@ -177,44 +179,32 @@ def get_type_symbol(self): def set_type_symbol(self, type_symbol): self.type_symbol = type_symbol - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ if self.is_encapsulated: - if self.compound_unit is ast: - return self - if self.compound_unit.get_parent(ast) is not None: - return self.compound_unit.get_parent(ast) + return [self.compound_unit] if self.is_pow: - if self.base is ast: - return self - if self.base.get_parent(ast) is not None: - return self.base.get_parent(ast) + return [self.base] + if self.is_arithmetic_expression(): - if isinstance(self.get_lhs(), ASTUnitType): - if self.get_lhs() is ast: - return self - if self.get_lhs().get_parent(ast) is not None: - return self.get_lhs().get_parent(ast) - if self.get_rhs() is ast: - return self - if self.get_rhs().get_parent(ast) is not None: - return self.get_rhs().get_parent(ast) - return None - - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + children = [] + if self.get_lhs() and isinstance(self.get_lhs(), ASTNode): + children.append(self.get_lhs()) + + if self.get_rhs() and isinstance(self.get_rhs(), ASTNode): + children.append(self.get_rhs()) + + return children + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTUnitType): return False diff --git a/pynestml/meta_model/ast_update_block.py b/pynestml/meta_model/ast_update_block.py index 7b2188856..27c84fa1a 100644 --- a/pynestml/meta_model/ast_update_block.py +++ b/pynestml/meta_model/ast_update_block.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_block import ASTBlock from pynestml.meta_model.ast_node import ASTNode @@ -81,27 +83,16 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None + return [self.get_block()] - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTUpdateBlock): return False diff --git a/pynestml/meta_model/ast_variable.py b/pynestml/meta_model/ast_variable.py index 1de5e0489..ecb200d46 100644 --- a/pynestml/meta_model/ast_variable.py +++ b/pynestml/meta_model/ast_variable.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Any, Optional +from typing import Any, List, Optional from copy import copy @@ -184,14 +184,6 @@ def set_delay_parameter(self, delay: str): assert (delay is not None), '(PyNestML.AST.Variable) No delay parameter provided' self.delay_parameter = delay - def get_parent(self, ast: ASTNode) -> Optional[ASTNode]: - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :return: AST if this or one of the child nodes contains the handed over element. - """ - return None - def is_unit_variable(self) -> bool: r""" Provided on-the-fly information whether this variable represents a unit-variable, e.g., nS. @@ -203,19 +195,28 @@ def is_unit_variable(self) -> bool: return True return False - def equals(self, other: Any) -> bool: - r""" - The equals method. - :param other: a different object. - :return: True if equals, otherwise False. - """ - if not isinstance(other, ASTVariable): - return False - return self.get_name() == other.get_name() and self.get_differential_order() == other.get_differential_order() - def is_delay_variable(self) -> bool: """ Returns whether it is a delay variable or not :return: True if the variable has a delay parameter, False otherwise """ return self.get_delay_parameter() is not None + + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. + """ + if self.has_vector_parameter(): + return [self.get_vector_parameter()] + + return [] + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. + """ + if not isinstance(other, ASTVariable): + return False + + return self.get_name() == other.get_name() and self.get_differential_order() == other.get_differential_order() diff --git a/pynestml/meta_model/ast_while_stmt.py b/pynestml/meta_model/ast_while_stmt.py index 4069a2d8e..88d87eb22 100644 --- a/pynestml/meta_model/ast_while_stmt.py +++ b/pynestml/meta_model/ast_while_stmt.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + from pynestml.meta_model.ast_block import ASTBlock from pynestml.meta_model.ast_expression import ASTExpression from pynestml.meta_model.ast_node import ASTNode @@ -90,31 +92,23 @@ def get_block(self): """ return self.block - def get_parent(self, ast): - """ - Indicates whether a this node contains the handed over node. - :param ast: an arbitrary meta_model node. - :type ast: AST_ - :return: AST if this or one of the child nodes contains the handed over element. - :rtype: AST_ or None - """ - if self.get_condition() is ast: - return self - if self.get_condition().get_parent(ast) is not None: - return self.get_condition().get_parent(ast) - if self.get_block() is ast: - return self - if self.get_block().get_parent(ast) is not None: - return self.get_block().get_parent(ast) - return None - - def equals(self, other): + def get_children(self) -> List[ASTNode]: + r""" + Returns the children of this node, if any. + :return: List of children of this node. """ - The equals method. - :param other: a different object. - :type other: object - :return: True if equals, otherwise False. - :rtype: bool + children = [] + if self.get_condition(): + children.append(self.get_condition()) + + if self.get_block(): + children.append(self.get_block()) + + return children + + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTWhileStmt): return False diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index 3b852ab5d..5dd4aa3e0 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -38,6 +38,7 @@ from pynestml.utils.logger import Logger from pynestml.utils.logger import LoggingLevel from pynestml.utils.string_utils import removesuffix +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor from pynestml.visitors.ast_visitor import ASTVisitor @@ -172,7 +173,7 @@ def visit_function_call(self, node): found_parent_assignment = False node_ = node while not found_parent_assignment: - node_ = self.parent_node.get_parent(node_) + node_ = node_.get_parent() # XXX TODO also needs to accept normal ASTExpression, ASTAssignment? if isinstance(node_, ASTInlineExpression): found_parent_assignment = True @@ -225,6 +226,10 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): new_neuron = neuron.clone() new_synapse = synapse.clone() + new_neuron.parent_ = None # set root element + new_neuron.accept(ASTParentVisitor()) + new_synapse.parent_ = None # set root element + new_synapse.accept(ASTParentVisitor()) new_neuron.accept(ASTSymbolTableVisitor()) new_synapse.accept(ASTSymbolTableVisitor()) @@ -541,6 +546,8 @@ def mark_post_port(_expr=None): # add modified versions of neuron and synapse to list # + new_neuron.accept(ASTParentVisitor()) + new_synapse.accept(ASTParentVisitor()) ast_symbol_table_visitor = ASTSymbolTableVisitor() ast_symbol_table_visitor.after_ast_rewrite_ = True new_neuron.accept(ast_symbol_table_visitor) diff --git a/pynestml/transformers/synapse_remove_post_port.py b/pynestml/transformers/synapse_remove_post_port.py index 5a81dce3d..8284e4402 100644 --- a/pynestml/transformers/synapse_remove_post_port.py +++ b/pynestml/transformers/synapse_remove_post_port.py @@ -28,6 +28,7 @@ from pynestml.utils.logger import Logger, LoggingLevel from pynestml.transformers.transformer import Transformer from pynestml.utils.ast_utils import ASTUtils +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from pynestml.frontend.frontend_configuration import FrontendConfiguration @@ -151,6 +152,11 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): # add modified versions of neuron and synapse to list # + new_neuron.parent_ = None # set root element + new_neuron.accept(ASTParentVisitor()) + new_synapse.parent_ = None # set root element + new_synapse.accept(ASTParentVisitor()) + new_neuron.accept(ASTSymbolTableVisitor()) new_synapse.accept(ASTSymbolTableVisitor()) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 94837a8b4..680d9635a 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -434,6 +434,10 @@ def create_internal_block(cls, model: ASTModel): ASTSourceLocation.get_added_source_position()) internal.update_scope(model.get_scope()) model.get_body().get_body_elements().append(internal) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + return model @classmethod @@ -449,6 +453,10 @@ def create_state_block(cls, model: ASTModel): state = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(), ASTSourceLocation.get_added_source_position()) model.get_body().get_body_elements().append(state) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + return model @classmethod @@ -654,24 +662,26 @@ def replace_var(_expr=None): if alternate_name: ast_ext_var.set_alternate_name(alternate_name) + ast_ext_var.parent_ = _expr + ast_ext_var.update_alt_scope(new_scope) from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor ast_ext_var.accept(ASTSymbolTableVisitor()) if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): Logger.log_message(None, -1, "ASTSimpleExpression replacement made (var = " + str( - ast_ext_var.get_name()) + ") in expression: " + str(node.get_parent(_expr)), None, LoggingLevel.INFO) + ast_ext_var.get_name()) + ") in expression: " + str(_expr.get_parent()), None, LoggingLevel.INFO) _expr.set_variable(ast_ext_var) return if isinstance(_expr, ASTVariable): - if isinstance(node.get_parent(_expr), ASTAssignment): - node.get_parent(_expr).lhs = ast_ext_var + if isinstance(_expr.get_parent(), ASTAssignment): + _expr.get_parent().lhs = ast_ext_var Logger.log_message(None, -1, "ASTVariable replacement made in expression: " - + str(node.get_parent(_expr)), None, LoggingLevel.INFO) - elif isinstance(node.get_parent(_expr), ASTSimpleExpression) and node.get_parent(_expr).is_variable(): - node.get_parent(_expr).set_variable(ast_ext_var) - elif isinstance(node.get_parent(_expr), ASTDeclaration): + + str(_expr.get_parent()), None, LoggingLevel.INFO) + elif isinstance(_expr.get_parent(), ASTSimpleExpression) and _expr.get_parent().is_variable(): + _expr.get_parent().set_variable(ast_ext_var) + elif isinstance(_expr.get_parent(), ASTDeclaration): # variable could occur on the left-hand side; ignore. Only replace if it occurs on the right-hand side. pass else: @@ -680,12 +690,14 @@ def replace_var(_expr=None): raise Exception() return - p = node.get_parent(var) + p = var.get_parent() Logger.log_message(None, -1, "Error: unhandled use of variable " + var_name + " in expression " + str(p), None, LoggingLevel.INFO) raise Exception() node.accept(ASTHigherOrderVisitor(lambda x: replace_var(x))) + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + node.accept(ASTParentVisitor()) @classmethod def add_suffix_to_decl_lhs(cls, decl, suffix: str): @@ -758,7 +770,7 @@ def visit_function_call(self, node): found_parent_assignment = False node_ = node while not found_parent_assignment: - node_ = self.parent_node.get_parent(node_) + node_ = node_.get_parent() # XXX TODO also needs to accept normal ASTExpression, ASTAssignment? if isinstance(node_, ASTInlineExpression): found_parent_assignment = True @@ -806,6 +818,9 @@ def move_decls(cls, var_name, from_block, to_block, var_name_suffix: str, block_ decl.accept(ast_symbol_table_visitor) ast_symbol_table_visitor.block_type_stack.pop() + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + to_block.accept(ASTParentVisitor()) + return decls @classmethod @@ -1024,11 +1039,16 @@ def add_declaration_to_internals(cls, neuron: ASTModel, variable_name: str, init if vector_variable is not None: ast_declaration.set_size_parameter(vector_variable.get_vector_parameter()) neuron.add_to_internals_block(ast_declaration) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + neuron.accept(ASTParentVisitor()) + ast_declaration.update_scope(neuron.get_internals_blocks()[0].get_scope()) symtable_visitor = ASTSymbolTableVisitor() symtable_visitor.block_type_stack.push(BlockType.INTERNALS) ast_declaration.accept(symtable_visitor) symtable_visitor.block_type_stack.pop() + return neuron @classmethod @@ -1066,6 +1086,9 @@ def add_declaration_to_state_block(cls, neuron: ASTModel, variable: str, initial ast_declaration.set_size_parameter(vector_variable.get_vector_parameter()) neuron.add_to_state_block(ast_declaration) + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + neuron.accept(ASTParentVisitor()) + symtable_visitor = ASTSymbolTableVisitor() symtable_visitor.block_type_stack.push(BlockType.STATE) ast_declaration.accept(symtable_visitor) diff --git a/pynestml/utils/chan_info_enricher.py b/pynestml/utils/chan_info_enricher.py index e3f4c1459..61ff60f6f 100644 --- a/pynestml/utils/chan_info_enricher.py +++ b/pynestml/utils/chan_info_enricher.py @@ -19,11 +19,11 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.utils.model_parser import ModelParser -from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor import sympy from pynestml.utils.mechs_info_enricher import MechsInfoEnricher +from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor class ChanInfoEnricher(MechsInfoEnricher): diff --git a/pynestml/utils/logger.py b/pynestml/utils/logger.py index 1996a52c6..06e95b804 100644 --- a/pynestml/utils/logger.py +++ b/pynestml/utils/logger.py @@ -29,7 +29,6 @@ from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.messages import MessageCode from pynestml.meta_model.ast_inline_expression import ASTInlineExpression -from pynestml.meta_model.ast_input_port import ASTInputPort class LoggingLevel(Enum): diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py index 4ba67e849..c5514bff8 100644 --- a/pynestml/utils/mechs_info_enricher.py +++ b/pynestml/utils/mechs_info_enricher.py @@ -19,14 +19,16 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from collections import defaultdict + from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.ast_utils import ASTUtils from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor -from pynestml.symbols.symbol import SymbolKind from pynestml.visitors.ast_visitor import ASTVisitor -from pynestml.symbols.predefined_functions import PredefinedFunctions -from collections import defaultdict -from pynestml.utils.ast_utils import ASTUtils class MechsInfoEnricher: @@ -118,6 +120,8 @@ def transform_ode_solutions(cls, neuron, mechs_info): mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + neuron.accept(ASTParentVisitor()) + return mechs_info @classmethod diff --git a/pynestml/utils/model_parser.py b/pynestml/utils/model_parser.py index f1a95f98d..7fabf361e 100644 --- a/pynestml/utils/model_parser.py +++ b/pynestml/utils/model_parser.py @@ -71,6 +71,7 @@ from pynestml.utils.messages import Messages from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor @@ -131,6 +132,11 @@ def parse_file(cls, file_path=None): ast_builder_visitor = ASTBuilderVisitor(stream.tokens) ast = ast_builder_visitor.visit(compilation_unit) + # create links back from children in the tree to their parents + for model in ast.get_model_list(): + model.parent_ = None # root node has no parent + model.accept(ASTParentVisitor()) + # create and update the corresponding symbol tables SymbolTable.initialize_symbol_table(ast.get_source_position()) for model in ast.get_model_list(): diff --git a/pynestml/visitors/ast_parent_visitor.py b/pynestml/visitors/ast_parent_visitor.py new file mode 100644 index 000000000..4e8f1886e --- /dev/null +++ b/pynestml/visitors/ast_parent_visitor.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# +# ast_parent_visitor.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.meta_model.ast_node import ASTNode +from pynestml.visitors.ast_visitor import ASTVisitor + + +class ASTParentVisitor(ASTVisitor): + r""" + For each node in the AST, assign its ``parent_`` attribute; in other words, make the AST a doubly-linked tree. + """ + + def __init__(self): + super(ASTParentVisitor, self).__init__() + + def visit(self, node: ASTNode): + r"""Set ``parent_`` property on all children to refer back to this node.""" + children = node.get_children() + for child in children: + child.parent_ = node diff --git a/tests/invalid/CoCoVectorInNonVectorDeclaration.nestml b/tests/invalid/CoCoVectorInNonVectorDeclaration.nestml index 984d536ec..e068e8b54 100644 --- a/tests/invalid/CoCoVectorInNonVectorDeclaration.nestml +++ b/tests/invalid/CoCoVectorInNonVectorDeclaration.nestml @@ -36,5 +36,6 @@ model CoCoVectorInNonVectorDeclaration: state: g_ex [ten] mV = 10mV g_in mV = 10mV + g_ex + parameters: ten integer = 10 diff --git a/tests/test_symbol_table_builder.py b/tests/test_symbol_table_builder.py index ba8082212..718b09ec3 100644 --- a/tests/test_symbol_table_builder.py +++ b/tests/test_symbol_table_builder.py @@ -36,6 +36,7 @@ from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.utils.logger import Logger, LoggingLevel from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor @@ -78,6 +79,8 @@ def test_symbol_table_builder(self): SymbolTable.initialize_symbol_table(ast.get_source_position()) symbol_table_visitor = ASTSymbolTableVisitor() for model in ast.get_model_list(): + model.parent_ = None # set root element + model.accept(ASTParentVisitor()) model.accept(symbol_table_visitor) SymbolTable.add_model_scope(name=model.get_name(), scope=model.get_scope())