Skip to content

Commit

Permalink
Merge pull request #731 from shangyian/add-linked-nodes-client
Browse files Browse the repository at this point in the history
  • Loading branch information
shangyian authored Aug 18, 2023
2 parents ab7f441 + 42f4286 commit 17d25f4
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 85 deletions.
2 changes: 1 addition & 1 deletion datajunction-clients/python/datajunction/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
Version for Hatch
"""
__version__ = "0.0.1a15"
__version__ = "0.0.1a18"
10 changes: 10 additions & 0 deletions datajunction-clients/python/datajunction/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@ def _set_column_attributes(
)
return response.json()

def _find_nodes_with_dimension(
    self,
    node_name,
):
    """
    Find all nodes with the dimension `node_name`.

    Sends a GET request to `/dimensions/{node_name}/nodes/` through the
    client session and returns the decoded JSON body of the response.
    """
    endpoint = f"/dimensions/{node_name}/nodes/"
    return self._session.get(endpoint).json()


class ClientEntity(BaseModel):
"""
Expand Down
89 changes: 5 additions & 84 deletions datajunction-clients/python/datajunction/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,6 @@ def restore_node(self, node_name: str) -> None:
#
# Nodes: SOURCE
#
def source(self, node_name: str) -> "Source":
    """
    Look up an existing source node by name.

    The name and the source node type are passed to `_verify_node_exists`;
    the returned dict is wrapped in a `Source` bound to this client, and the
    node's primary key is derived from the returned columns.
    """
    source_type = models.NodeType.SOURCE.value
    payload = self._verify_node_exists(node_name, type_=source_type)
    source_node = Source(**payload, dj_client=self)
    source_node.primary_key = self._primary_key_from_columns(payload["columns"])
    return source_node

def create_source( # pylint: disable=too-many-arguments
self,
name: str,
Expand Down Expand Up @@ -145,6 +130,7 @@ def create_source( # pylint: disable=too-many-arguments
columns=columns,
)
self._create_node(node=new_node, mode=mode)
new_node.refresh()
return new_node

def register_table(self, catalog: str, schema: str, table: str) -> Source:
Expand All @@ -162,21 +148,6 @@ def register_table(self, catalog: str, schema: str, table: str) -> Source:
#
# Nodes: TRANSFORM
#
def transform(self, node_name: str) -> "Transform":
    """
    Look up an existing transform node by name.

    The name and the transform node type are passed to
    `_verify_node_exists`; the returned dict is wrapped in a `Transform`
    bound to this client, and the node's primary key is derived from the
    returned columns.
    """
    transform_type = models.NodeType.TRANSFORM.value
    payload = self._verify_node_exists(node_name, type_=transform_type)
    transform_node = Transform(**payload, dj_client=self)
    transform_node.primary_key = self._primary_key_from_columns(
        payload["columns"],
    )
    return transform_node

def create_transform( # pylint: disable=too-many-arguments
self,
name: str,
Expand All @@ -200,26 +171,12 @@ def create_transform( # pylint: disable=too-many-arguments
query=query,
)
self._create_node(node=new_node, mode=mode)
new_node.refresh()
return new_node

#
# Nodes: DIMENSION
#
def dimension(self, node_name: str) -> "Dimension":
    """
    Look up an existing dimension node by name.

    The name and the dimension node type are passed to
    `_verify_node_exists`; the returned dict is wrapped in a `Dimension`
    bound to this client, and the node's primary key is derived from the
    returned columns.
    """
    dimension_type = models.NodeType.DIMENSION.value
    payload = self._verify_node_exists(node_name, type_=dimension_type)
    dimension_node = Dimension(**payload, dj_client=self)
    dimension_node.primary_key = self._primary_key_from_columns(
        payload["columns"],
    )
    return dimension_node

def create_dimension( # pylint: disable=too-many-arguments
self,
name: str,
Expand All @@ -243,26 +200,12 @@ def create_dimension( # pylint: disable=too-many-arguments
query=query,
)
self._create_node(node=new_node, mode=mode)
new_node.refresh()
return new_node

#
# Nodes: METRIC
#
def metric(self, node_name: str) -> "Metric":
    """
    Look up an existing metric node by name.

    The name and the metric node type are passed to `_verify_node_exists`;
    the returned dict is wrapped in a `Metric` bound to this client, and the
    node's primary key is derived from the returned columns.
    """
    metric_type = models.NodeType.METRIC.value
    payload = self._verify_node_exists(node_name, type_=metric_type)
    metric_node = Metric(**payload, dj_client=self)
    metric_node.primary_key = self._primary_key_from_columns(payload["columns"])
    return metric_node

def create_metric( # pylint: disable=too-many-arguments
self,
name: str,
Expand All @@ -286,35 +229,12 @@ def create_metric( # pylint: disable=too-many-arguments
query=query,
)
self._create_node(node=new_node, mode=mode)
new_node.refresh()
return new_node

#
# Nodes: CUBE
#
def cube(self, node_name: str) -> "Cube":  # pragma: no cover
    """
    Look up an existing cube node by name.

    Raises `DJClientException` when the response from `_get_cube` carries no
    "name" key. Otherwise the cube elements are split into metrics and
    dimensions (everything whose type is not "metric"), each rendered as
    `<node_name>.<name>`, and a `Cube` bound to this client is returned.
    """
    payload = self._get_cube(node_name)
    if "name" not in payload:
        raise DJClientException(f"Cube `{node_name}` does not exist")

    metrics = []
    dimensions = []
    for element in payload["cube_elements"]:
        # Decide the bucket first, then format the qualified column name.
        bucket = metrics if element["type"] == "metric" else dimensions
        bucket.append(f'{element["node_name"]}.{element["name"]}')

    return Cube(
        **payload,
        metrics=metrics,
        dimensions=dimensions,
        dj_client=self,
    )

def create_cube( # pylint: disable=too-many-arguments
self,
name: str,
Expand All @@ -338,4 +258,5 @@ def create_cube( # pylint: disable=too-many-arguments
display_name=display_name,
)
self._create_node(node=new_node, mode=mode) # pragma: no cover
new_node.refresh()
return new_node # pragma: no cover
86 changes: 86 additions & 0 deletions datajunction-clients/python/datajunction/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from datajunction import _internal, models
from datajunction.exceptions import DJClientException
from datajunction.nodes import Cube, Dimension, Metric, Source, Transform


class DJClient(_internal.DJClient):
Expand Down Expand Up @@ -280,3 +281,88 @@ def list_engines(self) -> List[dict]:
{"name": engine["name"], "version": engine["version"]}
for engine in json_response
]

# Read nodes
def source(self, node_name: str) -> Source:
    """
    Fetch the source node called `node_name`.

    Delegates the lookup to `_verify_node_exists` with the source node type,
    builds a `Source` attached to this client from the result, and fills in
    its primary key from the returned columns.
    """
    expected_type = models.NodeType.SOURCE.value
    node_info = self._verify_node_exists(node_name, type_=expected_type)
    result = Source(**node_info, dj_client=self)
    result.primary_key = self._primary_key_from_columns(node_info["columns"])
    return result

def transform(self, node_name: str) -> Transform:
    """
    Fetch the transform node called `node_name`.

    Delegates the lookup to `_verify_node_exists` with the transform node
    type, builds a `Transform` attached to this client from the result, and
    fills in its primary key from the returned columns.
    """
    expected_type = models.NodeType.TRANSFORM.value
    node_info = self._verify_node_exists(node_name, type_=expected_type)
    result = Transform(**node_info, dj_client=self)
    result.primary_key = self._primary_key_from_columns(node_info["columns"])
    return result

def dimension(self, node_name: str) -> Dimension:
    """
    Retrieves a Dimension node with that name if one exists.

    Args:
        node_name: Name of the dimension node to look up.

    Returns:
        A `Dimension` bound to this client, with `primary_key` populated
        from the columns returned for the node.

    Note: what happens for a missing or wrongly-typed node is decided by
    `_verify_node_exists`, which is not visible here.
    """
    # `Dimension` is imported at module level, so the annotation uses the
    # class directly, matching the sibling `source` and `transform` readers.
    node_dict = self._verify_node_exists(
        node_name,
        type_=models.NodeType.DIMENSION.value,
    )
    node = Dimension(
        **node_dict,
        dj_client=self,
    )
    node.primary_key = self._primary_key_from_columns(node_dict["columns"])
    return node

def metric(self, node_name: str) -> Metric:
    """
    Retrieves a Metric node with that name if one exists.

    Args:
        node_name: Name of the metric node to look up.

    Returns:
        A `Metric` bound to this client, with `primary_key` populated from
        the columns returned for the node.

    Note: what happens for a missing or wrongly-typed node is decided by
    `_verify_node_exists`, which is not visible here.
    """
    # `Metric` is imported at module level, so the annotation uses the class
    # directly, matching the sibling `source` and `transform` readers.
    node_dict = self._verify_node_exists(
        node_name,
        type_=models.NodeType.METRIC.value,
    )
    node = Metric(
        **node_dict,
        dj_client=self,
    )
    node.primary_key = self._primary_key_from_columns(node_dict["columns"])
    return node

def cube(self, node_name: str) -> Cube:  # pragma: no cover
    """
    Retrieves a Cube node with that name if one exists.

    Args:
        node_name: Name of the cube node to look up.

    Returns:
        A `Cube` bound to this client whose `metrics` and `dimensions` are
        lists of `<node_name>.<name>` strings built from the cube elements.

    Raises:
        DJClientException: If the response from `_get_cube` has no "name"
            key, which is treated as the cube not existing.
    """
    node_dict = self._get_cube(node_name)
    if "name" not in node_dict:
        raise DJClientException(f"Cube `{node_name}` does not exist")

    # Partition the elements in a single pass; relative order within each
    # list is the same as in `cube_elements`.
    metrics = []
    dimensions = []
    for col in node_dict["cube_elements"]:
        target = metrics if col["type"] == "metric" else dimensions
        target.append(f'{col["node_name"]}.{col["name"]}')

    return Cube(
        **node_dict,
        metrics=metrics,
        dimensions=dimensions,
        dj_client=self,
    )
9 changes: 9 additions & 0 deletions datajunction-clients/python/datajunction/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ class Dimension(NodeWithQuery):
query: str
columns: Optional[List[models.Column]]

def linked_nodes(self):
    """
    List the names of all nodes linked to this dimension.

    Asks the client for the nodes that have this dimension (by `self.name`)
    and returns the "name" entry of each one, in the order received.
    """
    linked = self.dj_client._find_nodes_with_dimension(self.name)
    return [entry["name"] for entry in linked]


class Cube(Node): # pylint: disable=abstract-method
"""
Expand Down
3 changes: 3 additions & 0 deletions datajunction-clients/python/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument
mode=NodeMode.PUBLISHED,
)
assert account_type_dim.name == "default.account_type"
assert len(account_type_dim.columns) == 3
assert "default.account_type" in client.list_dimensions(namespace="default")

# transform nodes
Expand All @@ -395,6 +396,8 @@ def test_create_nodes(self, client): # pylint: disable=unused-argument
"default.large_revenue_payments_only"
in client.namespace("default").transforms()
)
assert len(large_revenue_payments_only.columns) == 4

client.transform("default.large_revenue_payments_only")

result = large_revenue_payments_only.add_materialization(
Expand Down
21 changes: 21 additions & 0 deletions datajunction-clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,30 @@ def test_list_nodes(self, client):
"foo.bar.repair_orders_thin",
]

def test_find_nodes_with_dimension(self, client):
    """
    Verify `dimension.linked_nodes()` returns the expected node names for
    the `default.repair_order` dimension.
    """
    expected_linked_nodes = [
        "default.repair_order_details",
        "default.avg_repair_price",
        "default.total_repair_cost",
        "default.total_repair_order_discounts",
        "default.avg_repair_order_discounts",
    ]
    repair_order_dim = client.dimension("default.repair_order")
    assert repair_order_dim.linked_nodes() == expected_linked_nodes

#
# Get common metrics and dimensions
#
def test_common_dimensions(self, client):
    """
    Verify that the common dimensions of two metrics can be fetched and that
    eight of them are returned.
    """
    metric_names = ["default.num_repair_orders", "default.avg_repair_price"]
    dims = client.common_dimensions(metrics=metric_names)
    assert len(dims) == 8

#
# SQL and data
Expand Down

0 comments on commit 17d25f4

Please sign in to comment.