Skip to content

Commit

Permalink
Refactor ASTNode.get_parent() to improve runtime performance (#1049)
Browse files Browse the repository at this point in the history
* refactor ASTNode.get_parent() to improve runtime performance

---------

Co-authored-by: C.A.P. Linssen <[email protected]>
  • Loading branch information
clinssen and C.A.P. Linssen authored May 21, 2024
1 parent de1ee06 commit 076bd89
Show file tree
Hide file tree
Showing 60 changed files with 823 additions and 1,003 deletions.
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_all_variables_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# check if it is part of an invariant
# if it is the case, there is no "recursive" declaration
# so check if the parent is a declaration and the expression the invariant
expr_par = node.get_parent(expr)
expr_par = expr.get_parent()
if isinstance(expr_par, ASTDeclaration) and expr_par.get_invariant() == expr:
# in this case its ok if it is recursive or defined later on
continue
Expand Down
4 changes: 2 additions & 2 deletions pynestml/cocos/co_co_no_kernels_except_in_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def visit_variable(self, node: ASTNode):
if not symbol.is_kernel():
continue
if node.get_complete_name() == kernelName:
parent = self.__neuron_node.get_parent(node)
parent = node.get_parent()
if parent is not None:
if isinstance(parent, ASTKernel):
continue
grandparent = self.__neuron_node.get_parent(parent)
grandparent = parent.get_parent()
if grandparent is not None and isinstance(grandparent, ASTFunctionCall):
grandparent_func_name = grandparent.get_name()
if grandparent_func_name == 'convolve':
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_resolution_func_legally_used.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def visit_simple_expression(self, node):
if function_name == PredefinedFunctions.TIME_RESOLUTION:
_node = node
while _node:
_node = self.neuron.get_parent(_node)
_node = _node.get_parent()

if isinstance(_node, ASTEquationsBlock) \
or isinstance(_node, ASTFunction):
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_simple_delta_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def check_co_co(cls, model: ASTModel):
def check_simple_delta(_expr=None):
if _expr.is_function_call() and _expr.get_function_call().get_name() == "delta":
deltafunc = _expr.get_function_call()
parent = model.get_parent(_expr)
parent = _expr.get_parent()

# check the argument
if not (len(deltafunc.get_args()) == 1
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_vector_declaration_right_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class VectorDeclarationVisitor(ASTVisitor):
def visit_variable(self, node: ASTVariable):
vector_parameter = node.get_vector_parameter()
if vector_parameter is not None:
if isinstance(self._neuron.get_parent(node), ASTDeclaration):
if isinstance(node.get_parent(), ASTDeclaration):
# node is being declared: size should be >= 1
min_index = 1

Expand Down
20 changes: 9 additions & 11 deletions pynestml/meta_model/ast_arithmetic_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List
from pynestml.meta_model.ast_node import ASTNode


Expand Down Expand Up @@ -71,23 +72,20 @@ def clone(self):

return dup

def get_parent(self, ast):
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
"""
return None
return []

def equals(self, other):
# type: (ASTNode) -> bool
"""
def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTArithmeticOperator):
return False

return (self.is_times_op == other.is_times_op and self.is_div_op == other.is_div_op
and self.is_modulo_op == other.is_modulo_op and self.is_plus_op == other.is_plus_op
and self.is_minus_op == other.is_minus_op and self.is_pow_op == other.is_pow_op)
63 changes: 18 additions & 45 deletions pynestml/meta_model/ast_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import Optional
from typing import List, Optional

from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_variable import ASTVariable
Expand Down Expand Up @@ -124,34 +124,27 @@ def get_expression(self):
"""
return self.rhs

def get_parent(self, ast):
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
"""
if self.get_variable() is ast:
return self
if self.get_expression() is ast:
return self
if self.get_variable().get_parent(ast) is not None:
return self.get_variable().get_parent(ast)
if self.get_expression().get_parent(ast) is not None:
return self.get_expression().get_parent(ast)
return None

def equals(self, other):
"""
The equals operation.
:param other: a different object.
:type other: object
:return: True if equal, otherwise False.
:rtype: bool
children = []
if self.get_variable():
children.append(self.get_variable())

if self.get_expression():
children.append(self.get_expression())

return children

def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTAssignment):
return False

return (self.get_variable().equals(other.get_variable())
and self.is_compound_quotient == other.is_compound_quotient
and self.is_compound_product == other.is_compound_product
Expand All @@ -160,26 +153,6 @@ def equals(self, other):
and self.is_direct_assignment == other.is_direct_assignment
and self.get_expression().equals(other.get_expression()))

def deconstruct_compound_assignment(self):
"""
From lhs and rhs it constructs a new expression which corresponds to direct assignment.
E.g.: a += b*c -> a = a + b*c
:return: the rhs for an equivalent direct assignment.
:rtype: ast_expression
"""
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
# TODO: get rid of this through polymorphism?
assert not self.is_direct_assignment, "Can only be invoked on a compound assignment."

operator = self.extract_operator_from_compound_assignment()
lhs_variable = self.get_lhs_variable_as_expression()
rhs_in_brackets = self.get_bracketed_rhs_expression()
result = self.construct_equivalent_direct_assignment_rhs(operator, lhs_variable, rhs_in_brackets)
# create symbols for the new Expression:
visitor = ASTSymbolTableVisitor()
result.accept(visitor)
return result

def get_lhs_variable_as_expression(self):
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
# TODO: maybe calculate new source positions exactly?
Expand Down
25 changes: 10 additions & 15 deletions pynestml/meta_model/ast_bit_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List

from pynestml.meta_model.ast_node import ASTNode


Expand Down Expand Up @@ -82,23 +84,16 @@ def clone(self):

return dup

def get_parent(self, ast):
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
return None
return []

def equals(self, other):
"""
The equals method.
:param other: a different object.
:type other: object
:return: True if equal, otherwise False.
:rtype: bool
def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTBitOperator):
return False
Expand Down
32 changes: 11 additions & 21 deletions pynestml/meta_model/ast_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List

from pynestml.meta_model.ast_node import ASTNode


Expand Down Expand Up @@ -95,28 +97,16 @@ def delete_stmt(self, stmt):
"""
self.stmts.remove(stmt)

def get_parent(self, ast):
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
"""
for stmt in self.get_stmts():
if stmt is ast:
return self
if stmt.get_parent(ast) is not None:
return stmt.get_parent(ast)
return None

def equals(self, other):
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
The equals method.
:param other: a different object.
:type other: object
:return: True if equal, otherwise False.
:rtype: bool
return self.get_stmts()

def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTBlock):
return False
Expand Down
30 changes: 10 additions & 20 deletions pynestml/meta_model/ast_block_with_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List

from pynestml.meta_model.ast_node import ASTNode


Expand Down Expand Up @@ -112,28 +114,16 @@ def clear(self):
del self.declarations
self.declarations = list()

def get_parent(self, ast=None):
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
for stmt in self.get_declarations():
if stmt is ast:
return self
if stmt.get_parent(ast) is not None:
return stmt.get_parent(ast)
return None
return self.get_declarations()

def equals(self, other=None):
"""
The equals method.
:param other: a different object.
:type other: object
:return: True if equal, otherwise False
:rtype: bool
def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTBlockWithVariables):
return False
Expand Down
26 changes: 11 additions & 15 deletions pynestml/meta_model/ast_comparison_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List

from pynestml.meta_model.ast_node import ASTNode


Expand Down Expand Up @@ -94,26 +96,20 @@ def clone(self):

return dup

def get_parent(self, ast):
"""
Indicates whether a this node contains the handed over node.
:param ast: an arbitrary meta_model node.
:type ast: AST_
:return: AST if this or one of the child nodes contains the handed over element.
:rtype: AST_ or None
def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
:return: List of children of this node.
"""
return None
return []

def equals(self, other):
"""
The equals method.
:param other: a different object.
:type other: object
:return: True if equal, otherwise False.
:rtype: bool
def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTComparisonOperator):
return False

return (self.is_lt == other.is_lt and self.is_le == other.is_le
and self.is_eq == other.is_eq and self.is_ne == other.is_ne
and self.is_ne2 == other.is_ne2 and self.is_ge == other.is_ge and self.is_gt == other.is_gt)
Loading

0 comments on commit 076bd89

Please sign in to comment.