From 42f42868dab14988c0afcc511d6a43586f169ebd Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Fri, 18 Aug 2023 08:36:05 -0700 Subject: [PATCH] Add linked nodes for dimension to python client --- .../python/datajunction/__about__.py | 2 +- .../python/datajunction/_internal.py | 10 +++ .../python/datajunction/builder.py | 89 ++----------------- .../python/datajunction/client.py | 86 ++++++++++++++++++ .../python/datajunction/nodes.py | 9 ++ .../python/tests/test_builder.py | 3 + .../python/tests/test_client.py | 21 +++++ 7 files changed, 135 insertions(+), 85 deletions(-) diff --git a/datajunction-clients/python/datajunction/__about__.py b/datajunction-clients/python/datajunction/__about__.py index 9d2f03de2..d0996dbb6 100644 --- a/datajunction-clients/python/datajunction/__about__.py +++ b/datajunction-clients/python/datajunction/__about__.py @@ -1,4 +1,4 @@ """ Version for Hatch """ -__version__ = "0.0.1a15" +__version__ = "0.0.1a18" diff --git a/datajunction-clients/python/datajunction/_internal.py b/datajunction-clients/python/datajunction/_internal.py index 3ac374f9a..297f65f49 100644 --- a/datajunction-clients/python/datajunction/_internal.py +++ b/datajunction-clients/python/datajunction/_internal.py @@ -407,6 +407,16 @@ def _set_column_attributes( ) return response.json() + def _find_nodes_with_dimension( + self, + node_name, + ): + """ + Find all nodes with this dimension + """ + response = self._session.get(f"/dimensions/{node_name}/nodes/") + return response.json() + class ClientEntity(BaseModel): """ diff --git a/datajunction-clients/python/datajunction/builder.py b/datajunction-clients/python/datajunction/builder.py index 3120e2a0e..19af91045 100644 --- a/datajunction-clients/python/datajunction/builder.py +++ b/datajunction-clients/python/datajunction/builder.py @@ -101,21 +101,6 @@ def restore_node(self, node_name: str) -> None: # # Nodes: SOURCE # - def source(self, node_name: str) -> "Source": - """ - Retrieves a source node with that name if one exists. - """ - node_dict = self._verify_node_exists( - node_name, - type_=models.NodeType.SOURCE.value, - ) - node = Source( - **node_dict, - dj_client=self, - ) - node.primary_key = self._primary_key_from_columns(node_dict["columns"]) - return node - def create_source( # pylint: disable=too-many-arguments self, name: str, @@ -145,6 +130,7 @@ def create_source( # pylint: disable=too-many-arguments columns=columns, ) self._create_node(node=new_node, mode=mode) + new_node.refresh() return new_node def register_table(self, catalog: str, schema: str, table: str) -> Source: @@ -162,21 +148,6 @@ def register_table(self, catalog: str, schema: str, table: str) -> Source: # # Nodes: TRANSFORM # - def transform(self, node_name: str) -> "Transform": - """ - Retrieves a transform node with that name if one exists. - """ - node_dict = self._verify_node_exists( - node_name, - type_=models.NodeType.TRANSFORM.value, - ) - node = Transform( - **node_dict, - dj_client=self, - ) - node.primary_key = self._primary_key_from_columns(node_dict["columns"]) - return node - def create_transform( # pylint: disable=too-many-arguments self, name: str, @@ -200,26 +171,12 @@ def create_transform( # pylint: disable=too-many-arguments query=query, ) self._create_node(node=new_node, mode=mode) + new_node.refresh() return new_node # # Nodes: DIMENSION # - def dimension(self, node_name: str) -> "Dimension": - """ - Retrieves a Dimension node with that name if one exists. - """ - node_dict = self._verify_node_exists( - node_name, - type_=models.NodeType.DIMENSION.value, - ) - node = Dimension( - **node_dict, - dj_client=self, - ) - node.primary_key = self._primary_key_from_columns(node_dict["columns"]) - return node - def create_dimension( # pylint: disable=too-many-arguments self, name: str, @@ -243,26 +200,12 @@ def create_dimension( # pylint: disable=too-many-arguments query=query, ) self._create_node(node=new_node, mode=mode) + new_node.refresh() return new_node # # Nodes: METRIC # - def metric(self, node_name: str) -> "Metric": - """ - Retrieves a Metric node with that name if one exists. - """ - node_dict = self._verify_node_exists( - node_name, - type_=models.NodeType.METRIC.value, - ) - node = Metric( - **node_dict, - dj_client=self, - ) - node.primary_key = self._primary_key_from_columns(node_dict["columns"]) - return node - def create_metric( # pylint: disable=too-many-arguments self, name: str, @@ -286,35 +229,12 @@ def create_metric( # pylint: disable=too-many-arguments query=query, ) self._create_node(node=new_node, mode=mode) + new_node.refresh() return new_node # # Nodes: CUBE # - def cube(self, node_name: str) -> "Cube": # pragma: no cover - """ - Retrieves a Cube node with that name if one exists. - """ - node_dict = self._get_cube(node_name) - if "name" not in node_dict: - raise DJClientException(f"Cube `{node_name}` does not exist") - dimensions = [ - f'{col["node_name"]}.{col["name"]}' - for col in node_dict["cube_elements"] - if col["type"] != "metric" - ] - metrics = [ - f'{col["node_name"]}.{col["name"]}' - for col in node_dict["cube_elements"] - if col["type"] == "metric" - ] - return Cube( - **node_dict, - metrics=metrics, - dimensions=dimensions, - dj_client=self, - ) - def create_cube( # pylint: disable=too-many-arguments self, name: str, @@ -338,4 +258,5 @@ def create_cube( # pylint: disable=too-many-arguments display_name=display_name, ) self._create_node(node=new_node, mode=mode) # pragma: no cover + new_node.refresh() return new_node # pragma: no cover diff --git a/datajunction-clients/python/datajunction/client.py b/datajunction-clients/python/datajunction/client.py index 7f4da3024..ff5839179 100644 --- a/datajunction-clients/python/datajunction/client.py +++ b/datajunction-clients/python/datajunction/client.py @@ -8,6 +8,7 @@ from datajunction import _internal, models from datajunction.exceptions import DJClientException +from datajunction.nodes import Cube, Dimension, Metric, Source, Transform class DJClient(_internal.DJClient): @@ -280,3 +281,88 @@ def list_engines(self) -> List[dict]: {"name": engine["name"], "version": engine["version"]} for engine in json_response ] + + # Read nodes + def source(self, node_name: str) -> Source: + """ + Retrieves a source node with that name if one exists. + """ + node_dict = self._verify_node_exists( + node_name, + type_=models.NodeType.SOURCE.value, + ) + node = Source( + **node_dict, + dj_client=self, + ) + node.primary_key = self._primary_key_from_columns(node_dict["columns"]) + return node + + def transform(self, node_name: str) -> Transform: + """ + Retrieves a transform node with that name if one exists. + """ + node_dict = self._verify_node_exists( + node_name, + type_=models.NodeType.TRANSFORM.value, + ) + node = Transform( + **node_dict, + dj_client=self, + ) + node.primary_key = self._primary_key_from_columns(node_dict["columns"]) + return node + + def dimension(self, node_name: str) -> "Dimension": + """ + Retrieves a Dimension node with that name if one exists. + """ + node_dict = self._verify_node_exists( + node_name, + type_=models.NodeType.DIMENSION.value, + ) + node = Dimension( + **node_dict, + dj_client=self, + ) + node.primary_key = self._primary_key_from_columns(node_dict["columns"]) + return node + + def metric(self, node_name: str) -> "Metric": + """ + Retrieves a Metric node with that name if one exists. + """ + node_dict = self._verify_node_exists( + node_name, + type_=models.NodeType.METRIC.value, + ) + node = Metric( + **node_dict, + dj_client=self, + ) + node.primary_key = self._primary_key_from_columns(node_dict["columns"]) + return node + + def cube(self, node_name: str) -> "Cube": # pragma: no cover + """ + Retrieves a Cube node with that name if one exists. + """ + node_dict = self._get_cube(node_name) + if "name" not in node_dict: + raise DJClientException(f"Cube `{node_name}` does not exist") + dimensions = [ + f'{col["node_name"]}.{col["name"]}' + for col in node_dict["cube_elements"] + if col["type"] != "metric" + ] + metrics = [ + f'{col["node_name"]}.{col["name"]}' + for col in node_dict["cube_elements"] + if col["type"] == "metric" + ] + return Cube( + **node_dict, + metrics=metrics, + dimensions=dimensions, + dj_client=self, + ) diff --git a/datajunction-clients/python/datajunction/nodes.py b/datajunction-clients/python/datajunction/nodes.py index e13d34874..722cb36aa 100644 --- a/datajunction-clients/python/datajunction/nodes.py +++ b/datajunction-clients/python/datajunction/nodes.py @@ -327,6 +327,15 @@ class Dimension(NodeWithQuery): query: str columns: Optional[List[models.Column]] + def linked_nodes(self): + """ + Find all nodes linked to this dimension + """ + return [ + node["name"] + for node in self.dj_client._find_nodes_with_dimension(self.name) + ] + class Cube(Node): # pylint: disable=abstract-method """ diff --git a/datajunction-clients/python/tests/test_builder.py b/datajunction-clients/python/tests/test_builder.py index 52bf466a1..c9117af87 100644 --- a/datajunction-clients/python/tests/test_builder.py +++ b/datajunction-clients/python/tests/test_builder.py @@ -378,6 +378,7 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument mode=NodeMode.PUBLISHED, ) assert account_type_dim.name == "default.account_type" + assert len(account_type_dim.columns) == 3 assert "default.account_type" in client.list_dimensions(namespace="default") # transform nodes @@ -395,6 +396,8 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument "default.large_revenue_payments_only" in client.namespace("default").transforms() ) + assert len(large_revenue_payments_only.columns) == 4 + client.transform("default.large_revenue_payments_only") result = large_revenue_payments_only.add_materialization( diff --git a/datajunction-clients/python/tests/test_client.py b/datajunction-clients/python/tests/test_client.py index d50cc5731..afe5bceaa 100644 --- a/datajunction-clients/python/tests/test_client.py +++ b/datajunction-clients/python/tests/test_client.py @@ -280,9 +280,30 @@ def test_list_nodes(self, client): "foo.bar.repair_orders_thin", ] + def test_find_nodes_with_dimension(self, client): + """ + Check that `dimension.linked_nodes()` works as expected. + """ + repair_order_dim = client.dimension("default.repair_order") + assert repair_order_dim.linked_nodes() == [ + "default.repair_order_details", + "default.avg_repair_price", + "default.total_repair_cost", + "default.total_repair_order_discounts", + "default.avg_repair_order_discounts", + ] + # # Get common metrics and dimensions # + def test_common_dimensions(self, client): + """ + Test that getting common dimensions for metrics works + """ + dims = client.common_dimensions( + metrics=["default.num_repair_orders", "default.avg_repair_price"], + ) + assert len(dims) == 8 # # SQL and data