Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-45938: Add support for generating IDs on data model objects #100

Merged
merged 8 commits into from
Aug 29, 2024
4 changes: 4 additions & 0 deletions docs/changes/DM-45938.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Added automatic ID generation for objects in Felis schemas when the ``--id-generation`` flag is included on the command line.
This is supported for the ``create`` and ``validate`` commands.
gpdf marked this conversation as resolved.
Show resolved Hide resolved

Also added a Schema validator that checks if index names are unique.
17 changes: 15 additions & 2 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,16 @@
envvar="FELIS_LOGFILE",
help="Felis log file path",
)
def cli(log_level: str, log_file: str | None) -> None:
@click.option(
"--id-generation", is_flag=True, help="Generate IDs for all objects that do not have them", default=False
)
@click.pass_context
def cli(ctx: click.Context, log_level: str, log_file: str | None, id_generation: bool) -> None:
"""Felis command line tools"""
ctx.ensure_object(dict)
ctx.obj["id_generation"] = id_generation
if ctx.obj["id_generation"]:
logger.info("ID generation is enabled")
if log_file:
logging.basicConfig(filename=log_file, level=log_level)
else:
Expand All @@ -88,7 +96,9 @@ def cli(log_level: str, log_file: str | None) -> None:
)
@click.option("--ignore-constraints", is_flag=True, help="Ignore constraints when creating tables")
@click.argument("file", type=click.File())
@click.pass_context
def create(
ctx: click.Context,
engine_url: str,
schema_name: str | None,
initialize: bool,
Expand Down Expand Up @@ -124,7 +134,7 @@ def create(
"""
try:
yaml_data = yaml.safe_load(file)
schema = Schema.model_validate(yaml_data)
schema = Schema.model_validate(yaml_data, context={"id_generation": ctx.obj["id_generation"]})
url = make_url(engine_url)
if schema_name:
logger.info(f"Overriding schema name with: {schema_name}")
Expand Down Expand Up @@ -355,7 +365,9 @@ def load_tap(
default=False,
)
@click.argument("files", nargs=-1, type=click.File())
@click.pass_context
def validate(
ctx: click.Context,
check_description: bool,
check_redundant_datatypes: bool,
check_tap_table_indexes: bool,
Expand Down Expand Up @@ -402,6 +414,7 @@ def validate(
"check_redundant_datatypes": check_redundant_datatypes,
"check_tap_table_indexes": check_tap_table_indexes,
"check_tap_principal": check_tap_principal,
"id_generation": ctx.obj["id_generation"],
},
)
except ValidationError as e:
Expand Down
84 changes: 84 additions & 0 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,55 @@ class Schema(BaseObject):
id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
"""Map of IDs to objects."""

@model_validator(mode="before")
@classmethod
def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
    """Generate IDs for objects that do not have them.

    IDs are only generated when the validation context contains a truthy
    ``id_generation`` entry. An existing ``@id`` is never overwritten.
    Generated IDs have the form ``#<name>`` for the schema, tables,
    constraints and indexes, and ``#<table name>.<column name>`` for
    columns.

    Entries without a ``name`` key, and entries that are not dictionaries,
    are skipped instead of raising `KeyError` or `TypeError` from this
    hook; reporting them is left to the validation that runs afterwards.

    Parameters
    ----------
    values
        The values of the schema.
    info
        Validation context used to determine if ID generation is enabled.

    Returns
    -------
    `dict` [ `str`, `Any` ]
        The values of the schema with generated IDs.
    """
    context = info.context
    if not context or not context.get("id_generation", False):
        logger.debug("Skipping ID generation")
        return values
    if not isinstance(values, dict):
        # Only raw mapping data can have IDs filled in.
        return values

    def assign_id(obj: dict[str, Any], kind: str, prefix: str = "") -> None:
        # Fill in "@id" from the object's name unless it is already set or
        # there is no name to derive it from.
        if "@id" in obj or "name" not in obj:
            return
        obj["@id"] = f"#{prefix}{obj['name']}"
        logger.debug(f"Generated ID '{obj['@id']}' for {kind} '{obj['name']}'")

    def children(parent: dict[str, Any], key: str) -> list[dict[str, Any]]:
        # Return the dictionary entries stored under ``key``; a missing key,
        # an empty YAML value (None) or a non-sequence yields no entries.
        entries = parent.get(key)
        if not isinstance(entries, (list, tuple)):
            return []
        return [entry for entry in entries if isinstance(entry, dict)]

    assign_id(values, "schema")
    for table in children(values, "tables"):
        assign_id(table, "table")
        # Column IDs are qualified by the table name, so they can only be
        # generated when the table has one.
        if "name" in table:
            column_prefix = f"{table['name']}."
            for column in children(table, "columns"):
                assign_id(column, "column", column_prefix)
        for constraint in children(table, "constraints"):
            assign_id(constraint, "constraint")
        for index in children(table, "indexes"):
            assign_id(index, "index")
    return values

@field_validator("tables", mode="after")
@classmethod
def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
Expand Down Expand Up @@ -777,6 +826,11 @@ def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
def check_unique_constraint_names(self: Schema) -> Schema:
"""Check for duplicate constraint names in the schema.
Returns
-------
`Schema`
The schema being validated.
Raises
------
ValueError
Expand All @@ -798,6 +852,36 @@ def check_unique_constraint_names(self: Schema) -> Schema:

return self

@model_validator(mode="after")
def check_unique_index_names(self: Schema) -> Schema:
    """Verify that no index name is used more than once in the schema.

    Index names are compared across all tables, not just within a single
    table.

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if duplicate index names are found in the schema.
    """
    seen: set = set()
    repeated: list = []

    # Walk every index of every table in declaration order.
    for name in (index.name for table in self.tables for index in table.indexes):
        if name not in seen:
            seen.add(name)
            continue
        # Each repeat occurrence is recorded, so a name used three times
        # is listed twice.
        repeated.append(name)

    if not repeated:
        return self

    raise ValueError(f"Duplicate index names found in schema: {repeated}")

def _create_id_map(self: Schema) -> Schema:
"""Create a map of IDs to objects.
Expand Down
12 changes: 0 additions & 12 deletions tests/data/test-merge.yml

This file was deleted.

23 changes: 23 additions & 0 deletions tests/data/test_id_generation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: test_id_generation
description: Test schema for id generation.
tables:
- name: test_table
primaryKey: "#test_table.test_column1"
mysql:engine: MyISAM
columns:
- name: test_column1
datatype: int
description: Test column.
- name: test_column2
datatype: string
description: Test column.
length: 30
indexes:
- name: test_index
columns:
- test_column1
constraints:
- name: test_constraint
"@type": Unique
columns:
- test_column2
17 changes: 16 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

TESTDIR = os.path.abspath(os.path.dirname(__file__))
TEST_YAML = os.path.join(TESTDIR, "data", "test.yml")
TEST_MERGE_YAML = os.path.join(TESTDIR, "data", "test-merge.yml")


class CliTestCase(unittest.TestCase):
Expand Down Expand Up @@ -126,6 +125,22 @@ def test_validate_default(self) -> None:
result = runner.invoke(cli, ["validate", TEST_YAML], catch_exceptions=False)
self.assertEqual(result.exit_code, 0)

def test_id_generation(self) -> None:
    """Check that ``validate`` succeeds on a schema without IDs when the
    ``--id-generation`` flag is given.
    """
    schema_path = os.path.join(TESTDIR, "data", "test_id_generation.yaml")
    cli_args = ["--id-generation", "validate", schema_path]
    outcome = CliRunner().invoke(cli, cli_args, catch_exceptions=False)
    self.assertEqual(outcome.exit_code, 0)

def test_no_id_generation(self) -> None:
    """Check that ``validate`` exits with a non-zero code for a schema
    without IDs when ``--id-generation`` is not given.
    """
    schema_path = os.path.join(TESTDIR, "data", "test_id_generation.yaml")
    cli_args = ["validate", schema_path]
    outcome = CliRunner().invoke(cli, cli_args, catch_exceptions=False)
    self.assertNotEqual(outcome.exit_code, 0)

def test_validation_flags(self) -> None:
"""Test schema validation flags."""
runner = CliRunner()
Expand Down
24 changes: 24 additions & 0 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,36 @@ def test_check_unique_constraint_names(self) -> None:
with self.assertRaises(ValidationError):
Schema(name="testSchema", id="#test_id", tables=[test_tbl])

def test_check_unique_index_names(self) -> None:
    """Test that index names are unique.

    Two indexes with different IDs but the same name are attached to one
    table; building a schema from that table must raise a validation
    error.
    """
    test_col = Column(name="test_column1", id="#test_table#test_column1", datatype="int")
    test_col2 = Column(name="test_column2", id="#test_table#test_column2", datatype="string", length=256)
    test_tbl = Table(name="test_table", id="#test_table", columns=[test_col, test_col2])
    # Same name, distinct IDs: only the name collision should be at fault.
    test_idx = Index(name="idx_test", id="#idx_test", columns=[test_col.id])
    test_idx2 = Index(name="idx_test", id="#idx_test2", columns=[test_col2.id])
    test_tbl.indexes = [test_idx, test_idx2]
    with self.assertRaises(ValidationError):
        Schema(name="test_schema", id="#test-schema", tables=[test_tbl])

def test_model_validate(self) -> None:
    """Validate the schema data model against the YAML test file."""
    with open(TEST_YAML) as stream:
        schema_data = yaml.safe_load(stream)
    # The test passes as long as validation does not raise.
    Schema.model_validate(schema_data)

def test_id_generation(self) -> None:
    """Validate a schema without IDs with ID generation on, then off."""
    schema_path = os.path.join(TESTDIR, "data", "test_id_generation.yaml")

    # With ID generation enabled the schema must validate.
    with open(schema_path) as stream:
        generated_input = yaml.safe_load(stream)
    Schema.model_validate(generated_input, context={"id_generation": True})

    # Reload the file so the second run starts from the data as written on
    # disk rather than from the dictionary used by the first run.
    with open(schema_path) as stream:
        plain_input = yaml.safe_load(stream)
    # With ID generation disabled the same schema must be rejected.
    with self.assertRaises(ValidationError):
        Schema.model_validate(plain_input, context={"id_generation": False})


class SchemaVersionTest(unittest.TestCase):
"""Test the schema version."""
Expand Down
Loading