-
Notifications
You must be signed in to change notification settings - Fork 857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Training With Sparse Matrices #1629
Closed
Closed
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
b1c805a
Set up basic scaffolding for sparse amtrix support
talolard 9624761
Added logic and tests to load L_ind as a list of tuples
talolard e0c3f2d
refactor clique data creation so that it happens during create_tree
talolard 212db85
Added a function that predicts probs from a 'cliqueset' so that we do…
talolard dd3b0c7
sparse predictor returns a dict of [tuple,list]
talolard b7352fd
refactor to SparseLabelModel class
talolard c7692bf
Continued refactoring towards classes and improved tests
talolard 664ac6e
Moved KnownDimensions type to sparse_label_model
talolard bcb39f7
Differentiate between different kinds of sparse inputs
talolard 17e935e
Added tests that compare a sparse models output to regular model
talolard 133c1d2
Added tests to check that event sparse model trains the same as a reg…
talolard 8996bec
Added documentation
talolard c46a121
Pass mypy checks
talolard 01f6b56
Pass mypy checks
talolard a2f267a
Pass mypy checks
talolard 735ab51
Comply with tox docstrings
talolard e190d04
Ensure seed setting happens first in training
talolard f9b2629
Resolve 'Nits'
talolard d579832
Pass tox
talolard File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,9 +129,10 @@ dmypy.json | |
# Editors | ||
.vscode/ | ||
.code-workspace* | ||
|
||
.idea/ | ||
# Dask | ||
dask-worker-space/ | ||
|
||
# nohup | ||
nohup.out | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
from .baselines import MajorityClassVoter, MajorityLabelVoter, RandomVoter # noqa: F401 | ||
from .label_model import LabelModel # noqa: F401 | ||
from .sparse_data_helpers import ( | ||
train_model_from_known_objective, | ||
train_model_from_sparse_event_cooccurence, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Sparse Data Helpers | ||
|
||
Indexing throughout this module is 0 based, with the assumption that "abstains" are ommited. | ||
|
||
When working with larger datasets, it can be convenient to load the data in sparse format. This module | ||
provides utilities to do so. We provide functions for a number of cases. | ||
|
||
The user has the AugmentedMatrix (L_ind) in tuple form. AugmentedMatrix is of shape (num_examples,numfuncs*num_classes) | ||
and the user has a list of tuples (i,j) that indicate that event j occoured for example i. | ||
|
||
The user has a list of 3-tuples(i,j,k) such that for document i, labeling function j predicted class k. | ||
|
||
The user has a list of 3-tuples (i,j,c) where i and j range over [0,num_funcs*num_classes] such that | ||
the events i and j were observed to have co-occur c times. | ||
|
||
The user has a list of 3-tuples (i,j,f) where i and j range over [0,num_funcs*num_classes] such that | ||
the events i and j co-occur with frequency f where f is in (0,1] | ||
|
||
""" | ||
from snorkel.labeling.model.label_model import LabelModel | ||
from typing import List, Tuple, Iterable, Dict | ||
from scipy.sparse import csr_matrix | ||
import numpy as np | ||
import torch | ||
from snorkel.types.data import KnownDimensions | ||
|
||
|
||
def predict_probs_from_cliqueset( | ||
trained_model: LabelModel, cliqueset_indice_list: Iterable[Iterable[int]] | ||
): | ||
""" | ||
This function can make inference many orders of magnitude faster for larger datasets. | ||
|
||
In the data representation of L_ind where each row is a document and each column corresponds to an event " | ||
function x predicted class y", the 1s on L_ind essentially define a fully connected graph, or cliqueset. | ||
while their are num_classes^num_functions possible cliquesets, in practice we'll see a very small subset of | ||
those. | ||
In our exerpiments, where num_functions=40 and num_classes=3 we observed 600 cliquesets whereas 3^40 were possible. | ||
|
||
This function receives a trained model, and a list of cliquesets (indexed by event_id "func_id*num_labels+label_id") | ||
loads those in a sparse format and returns to predictions keyed by cliqueset | ||
|
||
|
||
|
||
""" | ||
rows = [] | ||
cols = [] | ||
data = [] | ||
for num, cs in enumerate(cliqueset_indice_list): | ||
for event_id in cs: | ||
rows.append(num) | ||
cols.append(event_id) | ||
data.append(1) | ||
sparse_input_l_ind = csr_matrix( | ||
(data, (rows, cols)), shape=(len(rows), trained_model.d) | ||
) | ||
predicted_probs = trained_model.predict_proba(sparse_input_l_ind.todense(),is_augmented=True) | ||
result_dict: Dict[tuple, np.array] ={} | ||
for cs, probs in zip(cliqueset_indice_list, predicted_probs): | ||
result_dict[tuple(cs)] = probs | ||
return result_dict | ||
|
||
|
||
def train_model_from_known_objective( | ||
talolard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
objective: np.array, known_dimensions: KnownDimensions, **kwargs | ||
): | ||
model = LabelModel(cardinality=known_dimensions.num_classes, **kwargs) | ||
model._set_constants(known_dimensions=known_dimensions) | ||
model.O = torch.from_numpy(objective) | ||
model._common_training_preamble() | ||
model._common_training_loop() | ||
return model | ||
|
||
|
||
def train_model_from_sparse_event_cooccurence( | ||
talolard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sparse_event_cooccurence: List[Tuple[int, int, int]], | ||
known_dimensions: KnownDimensions, | ||
): | ||
objective = _prepare_objective_from_sparse_event_cooccurence( | ||
sparse_event_cooccurence, known_dimensions | ||
) | ||
return train_model_from_known_objective( | ||
objective=objective, known_dimensions=known_dimensions | ||
) | ||
|
||
|
||
def _prepare_objective_from_sparse_event_cooccurence( | ||
sparse_event_cooccurence: List[Tuple[int, int, int]], | ||
known_dimensions: KnownDimensions, | ||
): | ||
sparse_L_ind = _prepare_sparse_L_ind(known_dimensions, sparse_event_cooccurence) | ||
objective = (sparse_L_ind.T @ sparse_L_ind) / known_dimensions.num_examples | ||
return objective.todense() | ||
|
||
|
||
def _prepare_sparse_L_ind(known_dimensions, sparse_event_cooccurence): | ||
rows = [] | ||
cols = [] | ||
data = [] | ||
for (row, col, count) in sparse_event_cooccurence: | ||
rows.append(row) | ||
cols.append(col) | ||
data.append(count) | ||
rows = np.array(rows) | ||
cols = np.array(cols) | ||
sparse_L_ind = csr_matrix( | ||
(data, (rows, cols),), # Notice that this is a tuple with a tuple | ||
shape=(known_dimensions.num_examples, known_dimensions.num_events), | ||
) | ||
return sparse_L_ind |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,18 @@ | ||
from typing import Any, Mapping, Sequence | ||
from typing import Any, Mapping, Sequence, NamedTuple, Optional | ||
|
||
DataPoint = Any | ||
DataPoints = Sequence[DataPoint] | ||
|
||
Field = Any | ||
FieldMap = Mapping[str, Field] | ||
class KnownDimensions(NamedTuple): | ||
talolard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_functions: int | ||
num_classes: int | ||
num_examples: Optional[int] | ||
|
||
@property | ||
def num_events(self): | ||
""" | ||
How many indicator random variables do we have (1 per event) | ||
""" | ||
return self.num_functions * self.num_classes |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's cool to add this to .gitignore, but then let's not add all the files in the .idea directory. Ditto with
workspace.code-workspace
—let's not add it to the repo, but if you want to add that type of file to .gitignore, that's fine.