Skip to content

Commit

Permalink
Use pathlib Paths; generate necessary base class files within header …
Browse files Browse the repository at this point in the history
…translator.
  • Loading branch information
tanaya-mankad committed Oct 4, 2024
1 parent 07f682b commit 5fb7479
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 40 deletions.
4 changes: 2 additions & 2 deletions dodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def task_generate_cpp_code():
"file_dep": [schema.path for schema in example.cpp_schemas]
+ [schema.meta_schema_path for schema in example.schemas]
+ [CORE_SCHEMA_PATH, BASE_META_SCHEMA_PATH, Path(SOURCE_PATH, "header_entries.py"), Path(SOURCE_PATH, "cpp_entries.py")],
"targets": [schema.cpp_header_path for schema in example.cpp_schemas]
+ [schema.cpp_source_path for schema in example.cpp_schemas]
"targets": [schema.cpp_header_file_path for schema in example.cpp_schemas]
+ [schema.cpp_source_file_path for schema in example.cpp_schemas]
+ example.cpp_support_headers + [example.cpp_output_dir / "CMakeLists.txt", example.cpp_output_dir / "src" / "CMakeLists.txt"],
"actions": [(example.generate_cpp_project, [["https://github.com/nlohmann/json.git", "https://github.com/bigladder/courier.git", "https://github.com/fmtlib/fmt.git"]])],
"clean": True,
Expand Down
6 changes: 4 additions & 2 deletions lattice/cpp_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,10 @@ def _get_items_to_serialize(self, header_tree):
# Shortcut to avoid creating "from_json" entries for the main class, but create them
# for all other classes. The main class relies on an "Initialize" function instead,
# dealt-with in the next block with function overrides.
if isinstance(entry, Struct) and entry.name not in self._namespace._name:
# Create the "from_json" function definition (header)
if (isinstance(entry, Struct) and
entry.name not in self._namespace._name and
len([c for c in entry.child_entries if isinstance(c, DataElement)])):
# Create the "from_json" function definition (header), only if it won't be empty
s = StructSerialization(entry.name, self._namespace)
for data_element_entry in [c for c in entry.child_entries if isinstance(c, DataElement)]:
# In function body, create each "get_to" for individual data elements
Expand Down
50 changes: 27 additions & 23 deletions lattice/header_entries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
import lattice.cpp.support_files as support
from .file_io import load, get_base_stem
from .util import snake_style, hyphen_separated_lowercase_style
from typing import Optional
Expand Down Expand Up @@ -392,7 +393,7 @@ def __init__(self, f_ret, f_name, f_args, name, parent):
@property
def value(self):
tab = "\t"
return f"{self._level * tab}{' '.join([self.ret_type, self.fname, self.args])}{self._closure}"
return f"{self._level * tab}{' '.join([self.ret_type, self.fname])}{self.args}{self._closure}"


# -------------------------------------------------------------------------------------------------
Expand All @@ -412,7 +413,7 @@ def __init__(self, name, parent):
class InitializeFunction(FunctionalHeaderEntry):

def __init__(self, name, parent):
super().__init__("void", "initialize", "(const nlohmann::json& j)", name, parent)
super().__init__("virtual void", "initialize", "(const nlohmann::json& j)", name, parent)


# # -------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -499,11 +500,15 @@ def modified_insertion_sort(obj_list):
return swapped

# fmt: off
def translate(self, input_file_path, top_namespace: str, forward_declarations_path: pathlib.Path):
def translate(self,
input_file_path: pathlib.Path,
forward_declarations_path: pathlib.Path,
output_path: pathlib.Path,
top_namespace: str):
"""X"""
self._source_dir = os.path.dirname(os.path.abspath(input_file_path))
self._source_dir = input_file_path.parent.resolve()
self._forward_declaration_dir = forward_declarations_path
self._schema_name = os.path.splitext(os.path.splitext(os.path.basename(input_file_path))[0])[0]
self._schema_name = get_base_stem(input_file_path)
self._references.clear()
self._derived_types.clear()
self._fundamental_data_types.clear()
Expand Down Expand Up @@ -553,9 +558,9 @@ def translate(self, input_file_path, top_namespace: str, forward_declarations_pa
self._namespace,
superclass=self._contents[base_level_tag].get("Data Group Template", ""),
)
self._add_header_dependencies(s)
self._add_header_dependencies(s, output_path)
# When there is a base class, add overrides:
# self._add_function_overrides(s, self._fundamental_base_class)
self._add_function_overrides(s, output_path, self._contents[base_level_tag].get("Data Group Template", ""))

# elif self._contents[base_level_tag].get('Object Type') == 'Grid Variables':
# s = Struct(base_level_tag, self._namespace, superclass='GridVariablesBase')
Expand Down Expand Up @@ -584,7 +589,7 @@ def translate(self, input_file_path, top_namespace: str, forward_declarations_pa
self._references,
self._search_nodes_for_datatype,
)
self._add_header_dependencies(d)
self._add_header_dependencies(d, output_path)
for data_element in self._contents[base_level_tag]["Data Elements"]:
d = DataIsSetElement(data_element, s)
for data_element in self._contents[base_level_tag]["Data Elements"]:
Expand Down Expand Up @@ -627,12 +632,12 @@ def _load_meta_info(self, schema_section):
"""Store the global/common types and the types defined by any named references."""
self._root_data_group = schema_section.get("Root Data Group")
refs: dict = {
f"{self._schema_name}": os.path.join(self._source_dir, f"{self._schema_name}.schema.yaml"),
"core": os.path.join(os.path.dirname(__file__), "core.schema.yaml"),
f"{self._schema_name}": self._source_dir / f"{self._schema_name}.schema.yaml",
"core": pathlib.Path(__file__).with_name("core.schema.yaml"),
}
if "References" in schema_section:
for ref in schema_section["References"]:
refs.update({f"{ref}": os.path.join(self._source_dir, ref + ".schema.yaml")})
refs.update({f"{ref}": self._source_dir / f"{ref}.schema.yaml"})
if (self._schema_name == "core" and
self._forward_declaration_dir and
self._forward_declaration_dir.is_dir()):
Expand Down Expand Up @@ -682,47 +687,46 @@ def _add_standard_dependency_headers(self, ref_list):
]
)

def _add_header_dependencies(self, data_element):
def _add_header_dependencies(self, data_element, generated_header_path: pathlib.Path):
"""Extract the dependency name from the data_element's type for included headers."""
if "core_ns" in data_element.type:
self._add_member_includes("core")
if "unique_ptr" in data_element.type:
m = re.search(r"\<(?P<base_class_type>.*)\>", data_element.type)
if m:
self._add_member_includes(m.group("base_class_type"), True)
self._add_member_includes(m.group("base_class_type"), generated_header_path)
if data_element.superclass:
self._add_member_includes(data_element.superclass, True)
self._add_member_includes(data_element.superclass, generated_header_path)
for external_source in data_element.external_reference_sources:
# This piece captures any "forward-declared" types that need to be
# processed by the DataElement type-finding mechanism before their header is known.
self._add_member_includes(external_source)

def _add_member_includes(self, dependency: str, base_class: bool = False):
def _add_member_includes(self, dependency: str, generated_base_class_path: Optional[pathlib.Path] = None):
"""
Add the dependency to the list of included headers,
and to the list of base classes if necessary.
"""
header_include = f"#include <{hyphen_separated_lowercase_style(dependency)}.h>"
if header_include not in self._preamble:
self._preamble.append(header_include)
if base_class:
self._required_base_classes.append(dependency)
if generated_base_class_path:
#self._required_base_classes.append(dependency)
support.generate_superclass_header(dependency, generated_base_class_path)

# fmt: off
def _add_function_overrides(self, parent_node, base_class_name):
def _add_function_overrides(self, parent_node, output_path, base_class_name):
"""Get base class virtual functions to be overridden."""
base_class = os.path.join(
os.path.dirname(__file__), "src", f"{hyphen_separated_lowercase_style(base_class_name)}.h"
)
base_class = pathlib.Path(output_path) / f"{hyphen_separated_lowercase_style(base_class_name)}.h"
try:
with open(base_class) as b:
for line in b:
if base_class_name not in line:
m = re.match(r"\s*virtual\s(?P<return_type>.*)\s(?P<name>.*)\((?P<arguments>.*)\)", line)
m = re.search(r"\s*virtual\s(?P<return_type>.*)\s(?P<name>.*)\((?P<arguments>.*)\)", line)
if m:
MemberFunctionOverride(m.group("return_type"),
m.group("name"),
f'({m.group("argument")})',
f'({m.group("arguments")})',
"",
parent_node)
except:
Expand Down
26 changes: 13 additions & 13 deletions lattice/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,19 @@ def json_schema_path(self, json_schema_path):
self._json_schema_path = Path(json_schema_path).absolute()

@property
def cpp_header_path(self): # pylint:disable=C0116
def cpp_header_file_path(self): # pylint:disable=C0116
return self._cpp_header_path

@cpp_header_path.setter
def cpp_header_path(self, value):
@cpp_header_file_path.setter
def cpp_header_file_path(self, value):
self._cpp_header_path = Path(value).absolute()

@property
def cpp_source_path(self): # pylint:disable=C0116
def cpp_source_file_path(self): # pylint:disable=C0116
return self._cpp_source_path

@cpp_source_path.setter
def cpp_source_path(self, value):
@cpp_source_file_path.setter
def cpp_source_file_path(self, value):
self._cpp_source_path = Path(value).absolute()


Expand Down Expand Up @@ -334,8 +334,8 @@ def setup_cpp_source_files(self):
self._cpp_output_include_dir = make_dir(include_dir / f"{self.root_directory.name}")
self._cpp_output_src_dir = make_dir(self.cpp_output_dir / "src")
for schema in self.cpp_schemas:
schema.cpp_header_path = self._cpp_output_include_dir / f"{schema.file_base_name.lower()}.h"
schema.cpp_source_path = self._cpp_output_src_dir / f"{schema.file_base_name.lower()}.cpp"
schema.cpp_header_file_path = self._cpp_output_include_dir / f"{schema.file_base_name.lower()}.h"
schema.cpp_source_file_path = self._cpp_output_src_dir / f"{schema.file_base_name.lower()}.cpp"

def setup_cpp_repository(self, submodules: list[str]):
"""Initialize the CPP output directory as a Git repo."""
Expand All @@ -362,12 +362,12 @@ def generate_cpp_project(self, submodules: list[str]):
h = HeaderTranslator()
c = CPPTranslator()
for schema in self.cpp_schemas:
h.translate(schema.path, self.root_directory.name, self.schema_directory_path)
dump(str(h), schema.cpp_header_path)
h.translate(schema.path, self.schema_directory_path, self._cpp_output_include_dir, self.root_directory.name)
dump(str(h), schema.cpp_header_file_path)
c.translate(self.root_directory.name, h)
dump(str(c), schema.cpp_source_path)
dump(str(c), schema.cpp_source_file_path)
self.setup_cpp_repository(submodules)
support.render_support_headers(self.root_directory.name, self._cpp_output_include_dir)
support.render_build_files(self.root_directory.name, submodules, self.cpp_output_dir)
for superclass in h.required_base_classes:
support.generate_superclass_header(superclass, self._cpp_output_include_dir)
# for superclass in h.required_base_classes:
# support.generate_superclass_header(superclass, self._cpp_output_include_dir)

0 comments on commit 5fb7479

Please sign in to comment.