diff --git a/src/obnb/label/split/__init__.py b/src/obnb/label/split/__init__.py index 7c2cf61..cce2b6d 100644 --- a/src/obnb/label/split/__init__.py +++ b/src/obnb/label/split/__init__.py @@ -1,4 +1,5 @@ """Genearting data splits from the labelset collection.""" +from obnb.label.split.explicit import ByTermSplit from obnb.label.split.holdout import ( AllHoldout, RandomRatioHoldout, @@ -10,9 +11,6 @@ RatioPartition, ThresholdPartition, ) -from obnb.label.split.explicit import ( - ByTermSplit -) __all__ = classes = [ "AllHoldout", diff --git a/src/obnb/label/split/base.py b/src/obnb/label/split/base.py index e52adab..db8a45a 100644 --- a/src/obnb/label/split/base.py +++ b/src/obnb/label/split/base.py @@ -41,10 +41,9 @@ def __call__( ids: List[str], y: np.ndarray, ) -> Iterator[Tuple[np.ndarray, ...]]: - """ - Split the input ids into multiple splits, e.g. a test, train, validation - split. The means by which this splitting occurs should be defined by - classes that inherit from this base class. + """Split the input ids into multiple splits, e.g. a test, train, validation + split. The means by which this splitting occurs should be defined by classes + that inherit from this base class. Note: Inheriting classes should yield the value instead of returning it, @@ -58,8 +57,9 @@ def __call__( Yields: Iterator of splits. Each split is a tuple of numpy arrays, where each array contains the IDs of the entities in the split. + """ - + raise NotImplementedError diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 1c7e249..1122e4d 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -1,36 +1,34 @@ from typing import Iterable, Iterator, List, Tuple import numpy +from numpy import ndarray from obnb.label.collection import LabelsetCollection from obnb.label.split.base import BaseSplit -from numpy import ndarray class ByTermSplit(BaseSplit): - """ - Produces splits based on an explicit list of terms. Genes - which match each term will be placed in the split corresponding - to that term. - - A split with a single term '*' will act as a catch-all for any - genes that weren't matched by any of the other splits. This would - allow you to, e.g., only retain a specific set of genes in the - training set, and place all others in the test set. - - Note that if the '*' split is not provided, any genes that don't - match any of the other splits will not be present in the returned - splits at all. + """Produces splits based on an explicit list of terms. Genes which match each term + will be placed in the split corresponding to that term. + + A split with a single term '*' will act as a catch-all for any genes that + weren't matched by any of the other splits. This would allow you to, e.g., + only retain a specific set of genes in the training set, and place all + others in the test set. + + Note that if the '*' split is not provided, any genes that don't match any + of the other splits will not be present in the returned splits at all. + """ def __init__( - self, labelset: LabelsetCollection, - split_terms: Iterable[Iterable[str]], - exclusive: bool = False - ) -> None: - """ - Initialize ByTermSplit object with reference labels and terms into - which to create splits. + self, + labelset: LabelsetCollection, + split_terms: Iterable[Iterable[str]], + exclusive: bool = False, + ) -> None: + """Initialize ByTermSplit object with reference labels and terms into which to + create splits. Args: labelset: LabelsetCollection object containing terms for each @@ -40,6 +38,7 @@ def __init__( terms that should be matched to place a gene in that split. exclusive: if True, a gene can occur only once across all the splits; it will belong to the first split in which it occurs. + """ self.labelset = labelset @@ -57,7 +56,7 @@ def __init__( self.long_df = df.melt( id_vars=["Name"], value_vars=df.columns.difference(["Info", "Size", "Name"]), - value_name="Value" + value_name="Value", ).dropna(subset=["Value"]) # group gene id and aggregate terms into a set, which makes @@ -72,12 +71,11 @@ def __init__( super().__init__() def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: - """ - For each gene ID, look up the term it's associated with - in the labelset, and place it in the corresponding split. + """For each gene ID, look up the term it's associated with in the labelset, and + place it in the corresponding split. + + Returns as many splits as there are elements in the split_terms tuple. - Returns as many splits as there are elements in the split_terms - tuple. """ # alias field to shorten the code below @@ -87,10 +85,13 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # term in the split result = [ ( - set([ - id for id in ids + { + id + for id in ids if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms - ]) if terms != {"*"} else None + } + if terms != {"*"} + else None ) for terms in self.split_terms ] @@ -100,13 +101,14 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # the other splits for idx in range(len(result)): if result[idx] is None: - result[idx] = set([ - id for id in ids + result[idx] = { + id + for id in ids if not any( gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms for terms in self.split_terms ) - ]) + } if self.exclusive: # if exclusive, remove genes in the current split that occurred @@ -120,4 +122,4 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # numpy arrays. we cast to list because leaving it as a set would cause # numpy.asarray() to create an array with a single element, the set, # rather than an array with the elements of the list - yield tuple([ numpy.asarray(list(x)) for x in result ]) + yield tuple([numpy.asarray(list(x)) for x in result])