-
Notifications
You must be signed in to change notification settings - Fork 857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Training With Sparse Matrices #1629
Closed
Closed
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
b1c805a
Set up basic scaffolding for sparse matrix support
talolard 9624761
Added logic and tests to load L_ind as a list of tuples
talolard e0c3f2d
refactor clique data creation so that it happens during create_tree
talolard 212db85
Added a function that predicts probs from a 'cliqueset' so that we do…
talolard dd3b0c7
sparse predictor returns a dict of [tuple,list]
talolard b7352fd
refactor to SparseLabelModel class
talolard c7692bf
Continued refactoring towards classes and improved tests
talolard 664ac6e
Moved KnownDimensions type to sparse_label_model
talolard bcb39f7
Differentiate between different kinds of sparse inputs
talolard 17e935e
Added tests that compare a sparse models output to regular model
talolard 133c1d2
Added tests to check that event sparse model trains the same as a reg…
talolard 8996bec
Added documentation
talolard c46a121
Pass mypy checks
talolard 01f6b56
Pass mypy checks
talolard a2f267a
Pass mypy checks
talolard 735ab51
Comply with tox docstrings
talolard e190d04
Ensure seed setting happens first in training
talolard f9b2629
Resolve 'Nits'
talolard d579832
Pass tox
talolard File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,9 +129,10 @@ dmypy.json | |
# Editors | ||
.vscode/ | ||
.code-workspace* | ||
|
||
.idea/ | ||
# Dask | ||
dask-worker-space/ | ||
|
||
# nohup | ||
nohup.out | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,9 @@ | |
from snorkel.labeling.model.base_labeler import BaseLabeler | ||
from snorkel.labeling.model.graph_utils import get_clique_tree | ||
from snorkel.labeling.model.logger import Logger | ||
from snorkel.labeling.model.sparse_label_model.sparse_label_model_helpers import ( | ||
KnownDimensions, | ||
) | ||
from snorkel.types import Config | ||
from snorkel.utils.config_utils import merge_config | ||
from snorkel.utils.lr_schedulers import LRSchedulerConfig | ||
|
@@ -188,19 +191,6 @@ def _get_augmented_label_matrix( | |
# Create a helper data structure which maps cliques (as tuples of member | ||
# sources) --> {start_index, end_index, maximal_cliques}, where | ||
# the last value is a set of indices in this data structure | ||
self.c_data: Dict[int, _CliqueData] = {} | ||
for i in range(self.m): | ||
self.c_data[i] = _CliqueData( | ||
start_index=i * self.cardinality, | ||
end_index=(i + 1) * self.cardinality, | ||
max_cliques=set( | ||
[ | ||
j | ||
for j in self.c_tree.nodes() | ||
if i in self.c_tree.node[j]["members"] | ||
] | ||
), | ||
) | ||
|
||
L_ind = self._create_L_ind(L) | ||
|
||
|
@@ -225,6 +215,21 @@ def _get_augmented_label_matrix( | |
else: | ||
return L_ind | ||
|
||
def _calculate_clique_data(self) -> None: | ||
self.c_data: Dict[int, _CliqueData] = {} | ||
for i in range(self.m): | ||
self.c_data[i] = _CliqueData( | ||
start_index=i * self.cardinality, | ||
end_index=(i + 1) * self.cardinality, | ||
max_cliques=set( | ||
[ | ||
j | ||
for j in self.c_tree.nodes() | ||
if i in self.c_tree.node[j]["members"] | ||
] | ||
), | ||
) | ||
|
||
def _build_mask(self) -> None: | ||
"""Build mask applied to O^{-1}, O for the matrix approx constraint.""" | ||
self.mask = torch.ones(self.d, self.d).bool() | ||
|
@@ -252,6 +257,12 @@ def _generate_O(self, L: np.ndarray, higher_order: bool = False) -> None: | |
""" | ||
L_aug = self._get_augmented_label_matrix(L, higher_order=higher_order) | ||
self.d = L_aug.shape[1] | ||
self._generate_O_from_L_aug(L_aug) | ||
|
||
def _generate_O_from_L_aug(self, L_aug: np.ndarray) -> None: | ||
"""Generates O from L_aug. Extracted to a separate method for the sake of testing | ||
|
||
""" | ||
self.O = ( | ||
torch.from_numpy(L_aug.T @ L_aug / self.n).float().to(self.config.device) | ||
) | ||
|
@@ -377,7 +388,7 @@ def get_weights(self) -> np.ndarray: | |
accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.cpu().detach().numpy()).sum() | ||
return np.clip(accs / self.coverage, 1e-6, 1.0) | ||
|
||
def predict_proba(self, L: np.ndarray) -> np.ndarray: | ||
def predict_proba(self, L: np.ndarray, is_augmented: bool = False) -> np.ndarray: | ||
r"""Return label probabilities P(Y | \lambda). | ||
|
||
Parameters | ||
|
@@ -400,9 +411,14 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray: | |
[0., 1.], | ||
[0., 1.]]) | ||
""" | ||
L_shift = L + 1 # convert to {0, 1, ..., k} | ||
self._set_constants(L_shift) | ||
L_aug = self._get_augmented_label_matrix(L_shift) | ||
if not is_augmented: | ||
# This is the usual mode | ||
L_shift = L + 1 # convert to {0, 1, ..., k} | ||
self._set_constants(L_shift) # TODO - Why do we need this here ? | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Review comment: `self._get_augmented_label_matrix` uses at least `self.cardinality`, which is set in this method. Remove the TODO? |
||
L_aug = self._get_augmented_label_matrix(L_shift) | ||
else: | ||
# The data came in augmented format, and constants are already set | ||
L_aug = L | ||
mu = self.mu.cpu().detach().numpy() | ||
jtm = np.ones(L_aug.shape[1]) | ||
|
||
|
@@ -580,15 +596,35 @@ def _set_class_balance( | |
) | ||
self.P = torch.diag(torch.from_numpy(self.p)).float().to(self.config.device) | ||
|
||
def _set_constants(self, L: np.ndarray) -> None: | ||
self.n, self.m = L.shape | ||
def _set_constants( | ||
self, | ||
L: Optional[np.ndarray] = None, | ||
known_dimensions: Optional[KnownDimensions] = None, | ||
) -> None: | ||
if L is None and known_dimensions is None: | ||
raise ValueError( | ||
"You must either provide a LabelMatrix or specify known_dimensions" | ||
) | ||
elif known_dimensions is not None: | ||
self.n = known_dimensions.num_examples | ||
self.m = known_dimensions.num_functions | ||
self.d = known_dimensions.num_events | ||
self.cardinality = known_dimensions.num_classes | ||
elif L is not None: | ||
# We know L is not none, but the linter can't figure it out ... | ||
self.n, self.m = L.shape | ||
else: | ||
raise ValueError( | ||
"Something impossible happened. This is here for the sake of the linter" | ||
) | ||
if self.m < 3: | ||
raise ValueError("L_train should have at least 3 labeling functions") | ||
self.t = 1 | ||
|
||
def _create_tree(self) -> None: | ||
nodes = range(self.m) | ||
self.c_tree = get_clique_tree(nodes, []) | ||
self._calculate_clique_data() | ||
|
||
def _execute_logging(self, loss: torch.Tensor) -> Metrics: | ||
self.eval() | ||
|
@@ -607,7 +643,6 @@ def _execute_logging(self, loss: torch.Tensor) -> Metrics: | |
# Reset running loss and examples counts | ||
self.running_loss = 0.0 | ||
self.running_examples = 0 | ||
|
||
self.train() | ||
return metrics_dict | ||
|
||
|
@@ -861,13 +896,6 @@ def fit( | |
>>> label_model.fit(L, class_balance=[0.7, 0.3], n_epochs=200, l2=0.4) | ||
""" | ||
# Set random seed | ||
self.train_config: TrainConfig = merge_config( # type:ignore | ||
TrainConfig(), kwargs # type:ignore | ||
) | ||
# Update base config so that it includes all parameters | ||
random.seed(self.train_config.seed) | ||
np.random.seed(self.train_config.seed) | ||
torch.manual_seed(self.train_config.seed) | ||
|
||
L_shift = L_train + 1 # convert to {0, 1, ..., k} | ||
if L_shift.max() > self.cardinality: | ||
|
@@ -876,15 +904,41 @@ def fit( | |
) | ||
|
||
self._set_constants(L_shift) | ||
self._set_class_balance(class_balance, Y_dev) | ||
self._create_tree() | ||
self._common_training_preamble( | ||
talolard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class_balance=class_balance, Y_dev=Y_dev, **kwargs | ||
) | ||
lf_analysis = LFAnalysis(L_train) | ||
self.coverage = lf_analysis.lf_coverages() | ||
|
||
# Compute O and initialize params | ||
if self.config.verbose: # pragma: no cover | ||
logging.info("Computing O...") | ||
self._generate_O(L_shift) | ||
self._common_training_loop() | ||
|
||
def _common_training_preamble( | ||
talolard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, | ||
Y_dev: Optional[np.ndarray] = None, | ||
class_balance: Optional[List[float]] = None, | ||
**kwargs: Any | ||
) -> None: | ||
""" | ||
Performs the training preamble, regardless of user input | ||
""" | ||
self.train_config: TrainConfig = merge_config( # type:ignore | ||
TrainConfig(), kwargs # type:ignore | ||
) | ||
# Update base config so that it includes all parameters | ||
random.seed(self.train_config.seed) | ||
np.random.seed(self.train_config.seed) | ||
torch.manual_seed(self.train_config.seed) | ||
self._set_class_balance(class_balance, Y_dev) | ||
self._create_tree() | ||
|
||
def _common_training_loop(self) -> None: | ||
""" | ||
Training Logic that is shared across different fit methods, irrespective of the user input format | ||
""" | ||
self._init_params() | ||
|
||
# Estimate \mu | ||
|
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's cool to add this to .gitignore, but then let's not add all the files in the .idea directory. Ditto with
workspace.code-workspace
—let's not add it to the repo, but if you want to add that type of file to .gitignore, that's fine.