Skip to content

Commit

Permalink
Applied pre-commit, committed style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
falquaddoomi committed Nov 15, 2024
1 parent 5749f97 commit 81c37ec
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
4 changes: 1 addition & 3 deletions src/obnb/label/split/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Genearting data splits from the labelset collection."""
from obnb.label.split.explicit import ByTermSplit
from obnb.label.split.holdout import (
AllHoldout,
RandomRatioHoldout,
Expand All @@ -10,9 +11,6 @@
RatioPartition,
ThresholdPartition,
)
from obnb.label.split.explicit import (
ByTermSplit
)

__all__ = classes = [
"AllHoldout",
Expand Down
10 changes: 5 additions & 5 deletions src/obnb/label/split/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ def __call__(
ids: List[str],
y: np.ndarray,
) -> Iterator[Tuple[np.ndarray, ...]]:
"""
Split the input ids into multiple splits, e.g. a test, train, validation
split. The means by which this splitting occurs should be defined by
classes that inherit from this base class.
"""Split the input ids into multiple splits, e.g. a test, train, validation
split. The means by which this splitting occurs should be defined by classes
that inherit from this base class.
Note:
Inheriting classes should yield the value instead of returning it,
Expand All @@ -58,8 +57,9 @@ def __call__(
Yields:
Iterator of splits. Each split is a tuple of numpy arrays, where
each array contains the IDs of the entities in the split.
"""

raise NotImplementedError


Expand Down
70 changes: 36 additions & 34 deletions src/obnb/label/split/explicit.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
from typing import Iterable, Iterator, List, Tuple

import numpy
from numpy import ndarray

from obnb.label.collection import LabelsetCollection
from obnb.label.split.base import BaseSplit
from numpy import ndarray


class ByTermSplit(BaseSplit):
"""
Produces splits based on an explicit list of terms. Genes
which match each term will be placed in the split corresponding
to that term.
A split with a single term '*' will act as a catch-all for any
genes that weren't matched by any of the other splits. This would
allow you to, e.g., only retain a specific set of genes in the
training set, and place all others in the test set.
Note that if the '*' split is not provided, any genes that don't
match any of the other splits will not be present in the returned
splits at all.
"""Produces splits based on an explicit list of terms. Genes which match each term
will be placed in the split corresponding to that term.
A split with a single term '*' will act as a catch-all for any genes that
weren't matched by any of the other splits. This would allow you to, e.g.,
only retain a specific set of genes in the training set, and place all
others in the test set.
Note that if the '*' split is not provided, any genes that don't match any
of the other splits will not be present in the returned splits at all.
"""

def __init__(
self, labelset: LabelsetCollection,
split_terms: Iterable[Iterable[str]],
exclusive: bool = False
) -> None:
"""
Initialize ByTermSplit object with reference labels and terms into
which to create splits.
self,
labelset: LabelsetCollection,
split_terms: Iterable[Iterable[str]],
exclusive: bool = False,
) -> None:
"""Initialize ByTermSplit object with reference labels and terms into which to
create splits.
Args:
labelset: LabelsetCollection object containing terms for each
Expand All @@ -40,6 +38,7 @@ def __init__(
terms that should be matched to place a gene in that split.
exclusive: if True, a gene can occur only once across all the
splits; it will belong to the first split in which it occurs.
"""

self.labelset = labelset
Expand All @@ -57,7 +56,7 @@ def __init__(
self.long_df = df.melt(
id_vars=["Name"],
value_vars=df.columns.difference(["Info", "Size", "Name"]),
value_name="Value"
value_name="Value",
).dropna(subset=["Value"])

# group gene id and aggregate terms into a set, which makes
Expand All @@ -72,12 +71,11 @@ def __init__(
super().__init__()

def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]:
"""
For each gene ID, look up the term it's associated with
in the labelset, and place it in the corresponding split.
"""For each gene ID, look up the term it's associated with in the labelset, and
place it in the corresponding split.
Returns as many splits as there are elements in the split_terms tuple.
Returns as many splits as there are elements in the split_terms
tuple.
"""

# alias field to shorten the code below
Expand All @@ -87,10 +85,13 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]:
# term in the split
result = [
(
set([
id for id in ids
{
id
for id in ids
if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms
]) if terms != {"*"} else None
}
if terms != {"*"}
else None
)
for terms in self.split_terms
]
Expand All @@ -100,13 +101,14 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]:
# the other splits
for idx in range(len(result)):
if result[idx] is None:
result[idx] = set([
id for id in ids
result[idx] = {
id
for id in ids
if not any(
gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms
for terms in self.split_terms
)
])
}

if self.exclusive:
# if exclusive, remove genes in the current split that occurred
Expand All @@ -120,4 +122,4 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]:
# numpy arrays. we cast to list because leaving it as a set would cause
# numpy.asarray() to create an array with a single element, the set,
# rather than an array with the elements of the list
yield tuple([ numpy.asarray(list(x)) for x in result ])
yield tuple([numpy.asarray(list(x)) for x in result])

0 comments on commit 81c37ec

Please sign in to comment.