Skip to content

Commit

Permalink
add type annotations to Python code
Browse files Browse the repository at this point in the history
  • Loading branch information
berquist committed Oct 7, 2024
1 parent 5dc8ad5 commit 57b90a9
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 176 deletions.
5 changes: 3 additions & 2 deletions scripts/format-diff
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import re
from subprocess import check_output,STDOUT
from typing import List

choke_points = [
"ser &",
Expand All @@ -11,11 +12,11 @@ choke_points = [
commit = sys.argv[1]
paths = sys.argv[2:]

def getoutput(cmd_arr):
def getoutput(cmd_arr: List[str]) -> str:
result = check_output(cmd_arr,stderr=STDOUT,stdin=None).decode("utf-8").rstrip("\n")
return result

def format_diff(commit, path):
def format_diff(commit: str, path: str) -> None:
cmd = ["git", "diff", commit, "HEAD", path ]
diff_text = getoutput(cmd)

Expand Down
11 changes: 6 additions & 5 deletions src/sst/core/model/xmlToPython.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@

import xml.etree.ElementTree as ET
import sys, os, re
from typing import Dict

import sst

def printTree(indent, node):
def printTree(indent: int, node: ET.Element) -> None:
print("%sBegin %s: %r"%(' '*indent, node.tag, node.attrib))
if node.text and len(node.text.strip()):
print("%sText: %s"%(' '*indent, node.text.strip()))
Expand All @@ -27,7 +28,7 @@ def printTree(indent, node):


# Various global lookups
sstVars = dict()
sstVars: Dict[str, str] = dict()
sstParams = dict()
sstLinks = dict()

Expand Down Expand Up @@ -59,14 +60,14 @@ def replaceEnvVar(matchobj: re.Match) -> str:
return string


def getLink(name):
def getLink(name: str) -> sst.Link:
if name not in sstLinks:
sstLinks[name] = sst.Link(name)
return sstLinks[name]



def getParamName(node):
def getParamName(node: ET.Element) -> str:
name = node.tag.strip()
if name[0] == "{":
ns, tag = name[1:].split("}")
Expand Down Expand Up @@ -127,7 +128,7 @@ def buildComp(compNode: ET.Element) -> None:



def buildGraph(graph):
def buildGraph(graph: ET.Element) -> None:
for comp in graph.findall("component"):
buildComp(comp)

Expand Down
4 changes: 2 additions & 2 deletions src/sst/core/testingframework/sst_test_engine_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

################################################################################

def startup_and_run(sst_core_bin_dir, test_mode):
def startup_and_run(sst_core_bin_dir: str, test_mode: int) -> None:
""" This is the main entry point for loading and running the SST Test Frameworks
Engine.
Expand Down Expand Up @@ -138,7 +138,7 @@ def _generic_exception_handler(exc_e):

####

def _verify_test_frameworks_is_available(sst_core_frameworks_dir):
def _verify_test_frameworks_is_available(sst_core_frameworks_dir: str) -> None:
""" Ensure that all test framework files are available.
:param: sst_core_frameworks_dir = Dir of the test frameworks
"""
Expand Down
33 changes: 17 additions & 16 deletions src/sst/core/testingframework/sst_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import threading
import signal
import time
from typing import Optional

import test_engine_globals
from sst_unittest_support import *
Expand All @@ -40,7 +41,7 @@
#from test_engine_junit import junit_to_xml_report_string

if not sys.warnoptions:
import os, warnings
import warnings
warnings.simplefilter("once") # Change the filter in this process
os.environ["PYTHONWARNINGS"] = "once" # Also affect subprocesses

Expand All @@ -54,12 +55,12 @@ class SSTTestCase(unittest.TestCase):
basic resource for how to develop tests for this frameworks.
"""

def __init__(self, methodName):
def __init__(self, methodName: str) -> None:
# NOTE: __init__ is called at startup for all tests before any
# setUpModules(), setUpClass(), setUp() and the like are called.
super(SSTTestCase, self).__init__(methodName)
self.testname = methodName
parent_module_path = os.path.dirname(sys.modules[self.__class__.__module__].__file__)
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore
self._testsuite_dirpath = parent_module_path
#log_forced("SSTTestCase: __init__() - {0}".format(self.testname))
self.initializeClass(self.testname)
Expand All @@ -68,7 +69,7 @@ def __init__(self, methodName):

###

def initializeClass(self, testname):
def initializeClass(self, testname: str) -> None:
""" The method is called by the Frameworks immediately before class is
initialized.
Expand All @@ -92,7 +93,7 @@ def initializeClass(self, testname):

###

def setUp(self):
def setUp(self) -> None:
""" The method is called by the Frameworks immediately before a test is run
**NOTICE**:
Expand All @@ -115,7 +116,7 @@ def setUp(self):

###

def tearDown(self):
def tearDown(self) -> None:
""" The method is called by the Frameworks immediately after a test finishes
**NOTICE**:
Expand All @@ -136,7 +137,7 @@ def tearDown(self):
###

@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
""" This method is called by the Frameworks immediately before the TestCase starts
**NOTICE**:
Expand All @@ -154,7 +155,7 @@ def setUpClass(cls):
###

@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
""" This method is called by the Frameworks immediately after a TestCase finishes
**NOTICE**:
Expand All @@ -171,7 +172,7 @@ def tearDownClass(cls):

###

def get_testsuite_name(self):
def get_testsuite_name(self) -> str:
""" Return the testsuite (module) name
Returns:
Expand All @@ -181,7 +182,7 @@ def get_testsuite_name(self):

###

def get_testcase_name(self):
def get_testcase_name(self) -> str:
""" Return the testcase name
Returns:
Expand All @@ -190,7 +191,7 @@ def get_testcase_name(self):
return "{0}".format(strqual(self.__class__))
###

def get_testsuite_dir(self):
def get_testsuite_dir(self) -> str:
""" Return the directory path of the testsuite that is being run
Returns:
Expand All @@ -200,7 +201,7 @@ def get_testsuite_dir(self):

###

def get_test_output_run_dir(self):
def get_test_output_run_dir(self) -> str:
""" Return the path of the test output run directory
Returns:
Expand All @@ -210,7 +211,7 @@ def get_test_output_run_dir(self):

###

def get_test_output_tmp_dir(self):
def get_test_output_tmp_dir(self) -> str:
""" Return the path of the test tmp directory
Returns:
Expand All @@ -220,7 +221,7 @@ def get_test_output_tmp_dir(self):

###

def get_test_runtime_sec(self):
def get_test_runtime_sec(self) -> float:
""" Return the current runtime (walltime) of the test
Returns:
Expand Down Expand Up @@ -377,7 +378,7 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
### Module level support
################################################################################

def setUpModule():
def setUpModule() -> None:
""" Perform setup functions before the testing Module loads.
This function is called by the Frameworks before tests in any TestCase
Expand All @@ -400,7 +401,7 @@ def setUpModule():

###

def tearDownModule():
def tearDownModule() -> None:
""" Perform teardown functions immediately after a testing Module finishes.
This function is called by the Frameworks after all tests in all TestCases
Expand Down
Loading

0 comments on commit 57b90a9

Please sign in to comment.