Skip to content

Commit

Permalink
attempt to improve node info subclass typing
Browse files Browse the repository at this point in the history
method from #26 (comment)
  • Loading branch information
dhimmel committed Jul 12, 2023
1 parent b68411d commit 18c2c9b
Showing 3 changed files with 23 additions and 10 deletions.
3 changes: 2 additions & 1 deletion nxontology/node.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
# Type definitions. networkx does not declare types.
# https://github.com/networkx/networkx/issues/3988#issuecomment-639969263
NodeT = TypeVar("NodeT", bound=Hashable)
NodeInfoT = TypeVar("NodeInfoT")


class NodeInfo(Freezable, Generic[NodeT]):
@@ -35,7 +36,7 @@ class NodeInfo(Freezable, Generic[NodeT]):
Each ic_metric has a scaled version accessible by adding a _scaled suffix.
"""

def __init__(self, nxo: NXOntology[NodeT], node: NodeT):
def __init__(self, nxo: NXOntology[NodeT, NodeInfoT], node: NodeT):
if node not in nxo.graph:
raise NodeNotFound(f"{node} not in graph.")
self.nxo = nxo
18 changes: 13 additions & 5 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import itertools
import json
import logging
from abc import abstractmethod
from os import PathLike, fspath
from typing import Any, Generic, Iterable, cast

@@ -12,7 +13,7 @@
from networkx.algorithms.isolate import isolates
from networkx.readwrite.json_graph import node_link_data, node_link_graph

from nxontology.node import NodeT
from nxontology.node import NodeInfoT, NodeT

from .exceptions import DuplicateError, NodeNotFound
from .node import NodeInfo
@@ -22,7 +23,7 @@
logger = logging.getLogger(__name__)


class NXOntology(Freezable, Generic[NodeT]):
class NXOntologyBase(Freezable, Generic[NodeT, NodeInfoT]):
"""
Encapsulate a networkx.DiGraph to represent an ontology.
Regarding edge directionality, parent terms should point to child term.
@@ -77,7 +78,7 @@ def write_node_link_json(self, path: str | PathLike[str]) -> None:
write_file.write("\n") # json.dump does not include a trailing newline

@classmethod
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[NodeT]:
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntologyBase[NodeT]:
"""
Retrun a new graph from node-link format as written by `write_node_link_json`.
"""
@@ -213,7 +214,8 @@ def compute_similarities(
yield metrics

@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
@abstractmethod
def _get_node_info_cls(cls) -> type[NodeInfoT]:
"""
Return the Node_Info class to use for this ontology.
Subclasses can override this to use a custom Node_Info class.
@@ -222,7 +224,7 @@ def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
"""
return NodeInfo

def node_info(self, node: NodeT) -> NodeInfo[NodeT]:
def node_info(self, node: NodeT) -> NodeInfoT:
"""
Return Node_Info instance for `node`.
If frozen, cache node info in `self._node_info_cache`.
@@ -306,3 +308,9 @@ def set_graph_attributes(
self.graph.graph["node_identifier_attribute"] = node_identifier_attribute
if node_url_attribute:
self.graph.graph["node_url_attribute"] = node_url_attribute


class NXOntology(NXOntologyBase[NodeT, NodeInfo[NodeT]]):
@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
return NodeInfo
12 changes: 8 additions & 4 deletions nxontology/similarity.py
Original file line number Diff line number Diff line change
@@ -3,12 +3,14 @@
import math
from typing import TYPE_CHECKING, Any, Generic

from nxontology.ontology import NXOntologyBase

if TYPE_CHECKING:
from nxontology.ontology import NXOntology
pass

from networkx import shortest_path_length

from nxontology.node import NodeInfo, NodeT
from nxontology.node import NodeInfo, NodeInfoT, NodeT
from nxontology.utils import Freezable, cache_on_frozen


@@ -29,7 +31,9 @@ class Similarity(Freezable, Generic[NodeT]):
"batet_log",
]

def __init__(self, nxo: NXOntology[NodeT], node_0: NodeT, node_1: NodeT):
def __init__(
self, nxo: NXOntologyBase[NodeT, NodeInfoT], node_0: NodeT, node_1: NodeT
):
self.nxo = nxo
self.node_0 = node_0
self.node_1 = node_1
@@ -125,7 +129,7 @@ class SimilarityIC(Similarity[NodeT]):

def __init__(
self,
nxo: NXOntology[NodeT],
nxo: NXOntologyBase[NodeT],
node_0: NodeT,
node_1: NodeT,
ic_metric: str = "intrinsic_ic_sanchez",

0 comments on commit 18c2c9b

Please sign in to comment.