Support Training With Sparse Matrices #1629
.gitignore

@@ -129,9 +129,10 @@ dmypy.json
# Editors
.vscode/
.code-workspace*

.idea/
# Dask
dask-worker-space/

# nohup
nohup.out
snorkel/labeling/model/label_model.py

@@ -14,6 +14,9 @@
 from snorkel.labeling.model.base_labeler import BaseLabeler
 from snorkel.labeling.model.graph_utils import get_clique_tree
 from snorkel.labeling.model.logger import Logger
+from snorkel.labeling.model.sparse_label_model.sparse_label_model_helpers import (
+    KnownDimensions,
+)
 from snorkel.types import Config
 from snorkel.utils.config_utils import merge_config
 from snorkel.utils.lr_schedulers import LRSchedulerConfig
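`KnownDimensions` lives in the new `sparse_label_model_helpers` module, which is not part of this diff. Its required fields can be read off the new `_set_constants` further down: `num_examples`, `num_functions`, `num_events`, and `num_classes`. A minimal sketch consistent with that usage (deriving `num_events` as `m * cardinality` is an assumption based on `_calculate_clique_data` below, not something this diff states):

```python
from typing import NamedTuple


class KnownDimensions(NamedTuple):
    num_examples: int   # n: number of data points
    num_functions: int  # m: number of labeling functions
    num_classes: int    # cardinality: number of label classes

    @property
    def num_events(self) -> int:
        # d: width of the augmented indicator matrix, one column per
        # (labeling function, class) pair; this derivation is an
        # assumption inferred from _calculate_clique_data below.
        return self.num_functions * self.num_classes
```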
@@ -188,19 +191,6 @@ def _get_augmented_label_matrix(
         # Create a helper data structure which maps cliques (as tuples of member
         # sources) --> {start_index, end_index, maximal_cliques}, where
         # the last value is a set of indices in this data structure
-        self.c_data: Dict[int, _CliqueData] = {}
-        for i in range(self.m):
-            self.c_data[i] = _CliqueData(
-                start_index=i * self.cardinality,
-                end_index=(i + 1) * self.cardinality,
-                max_cliques=set(
-                    [
-                        j
-                        for j in self.c_tree.nodes()
-                        if i in self.c_tree.node[j]["members"]
-                    ]
-                ),
-            )

         L_ind = self._create_L_ind(L)
@@ -225,6 +215,21 @@ def _get_augmented_label_matrix(
         else:
             return L_ind

+    def _calculate_clique_data(self) -> None:
+        self.c_data: Dict[int, _CliqueData] = {}
+        for i in range(self.m):
+            self.c_data[i] = _CliqueData(
+                start_index=i * self.cardinality,
+                end_index=(i + 1) * self.cardinality,
+                max_cliques=set(
+                    [
+                        j
+                        for j in self.c_tree.nodes()
+                        if i in self.c_tree.node[j]["members"]
+                    ]
+                ),
+            )
+
     def _build_mask(self) -> None:
         """Build mask applied to O^{-1}, O for the matrix approx constraint."""
         self.mask = torch.ones(self.d, self.d).bool()
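The structure built here encodes a fixed column layout for the augmented matrix: labeling function `i` owns the `cardinality` columns from `i * cardinality` up to, but not including, `(i + 1) * cardinality`. A toy illustration with made-up sizes:

```python
# Illustrating the column layout that _calculate_clique_data records;
# the sizes here are illustrative, not from the PR.
m, cardinality = 3, 2  # three labeling functions, binary task
for i in range(m):
    start, end = i * cardinality, (i + 1) * cardinality
    print(f"LF {i} -> columns {start}..{end - 1}")
# LF 0 -> columns 0..1
# LF 1 -> columns 2..3
# LF 2 -> columns 4..5
```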
@@ -252,6 +257,10 @@ def _generate_O(self, L: np.ndarray, higher_order: bool = False) -> None:
         """
         L_aug = self._get_augmented_label_matrix(L, higher_order=higher_order)
         self.d = L_aug.shape[1]
+        self._generate_O_from_L_aug(L_aug)
+
+    def _generate_O_from_L_aug(self, L_aug: np.ndarray) -> None:
+        """Generate O from L_aug. Extracted to a separate method for the sake of testing."""
         self.O = (
             torch.from_numpy(L_aug.T @ L_aug / self.n).float().to(self.config.device)
         )
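Extracting `_generate_O_from_L_aug` lets a test (or a sparse code path) hand the model a pre-built augmented matrix and compute the overlaps matrix O = L_aug^T L_aug / n directly. A minimal NumPy sketch of that same formula, with a hand-built `L_aug` whose values are illustrative only:

```python
import numpy as np

# Hand-built augmented matrix: n=4 examples, d=6 indicator columns
# (3 LFs x 2 classes). Entry (e, j) is 1 iff column j's event fired
# on example e.
L_aug = np.array(
    [
        [1, 0, 0, 1, 1, 0],
        [0, 1, 1, 0, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
    ],
    dtype=float,
)
n = L_aug.shape[0]
O = L_aug.T @ L_aug / n  # the same formula _generate_O_from_L_aug wraps in torch
print(O.shape)  # (6, 6)
```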
@@ -377,7 +386,7 @@ def get_weights(self) -> np.ndarray:
             accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.cpu().detach().numpy()).sum()
         return np.clip(accs / self.coverage, 1e-6, 1.0)

-    def predict_proba(self, L: np.ndarray) -> np.ndarray:
+    def predict_proba(self, L: np.ndarray, is_augmented: bool = False) -> np.ndarray:
         r"""Return label probabilities P(Y | \lambda).

         Parameters
@@ -400,9 +409,14 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
                [0., 1.],
                [0., 1.]])
         """
-        L_shift = L + 1  # convert to {0, 1, ..., k}
-        self._set_constants(L_shift)
-        L_aug = self._get_augmented_label_matrix(L_shift)
+        if not is_augmented:
+            # This is the usual mode
+            L_shift = L + 1  # convert to {0, 1, ..., k}
+            self._set_constants(L_shift)  # TODO - Why do we need this here ?
+            L_aug = self._get_augmented_label_matrix(L_shift)
+        else:
+            # The data came in augmented format, and constants are already set
+            L_aug = L
         mu = self.mu.cpu().detach().numpy()
         jtm = np.ones(L_aug.shape[1])

Review comment on the `self._set_constants(L_shift)  # TODO` line: self._get_augmented_label_matrix uses at least self.cardinality, which is set in this method. Remove the TODO?
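A hypothetical round trip with the new flag, assuming a fitted model (the data and variable names here are illustrative): build the augmented matrix once, then reuse it for inference without re-augmenting.

```python
import numpy as np

from snorkel.labeling.model import LabelModel

# Toy matrix: 1000 examples, 4 LFs, labels in {-1, 0, 1} (-1 = abstain).
L = np.random.randint(-1, 2, size=(1000, 4))

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L, seed=123)

# Usual mode: augmentation happens inside predict_proba.
probs = label_model.predict_proba(L)

# New mode: shift and augment once, then pass the augmented matrix directly.
L_shift = L + 1
label_model._set_constants(L_shift)
L_aug = label_model._get_augmented_label_matrix(L_shift)
probs_again = label_model.predict_proba(L_aug, is_augmented=True)

assert np.allclose(probs, probs_again)
```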
@@ -580,15 +594,35 @@ def _set_class_balance(
             )
         self.P = torch.diag(torch.from_numpy(self.p)).float().to(self.config.device)

-    def _set_constants(self, L: np.ndarray) -> None:
-        self.n, self.m = L.shape
+    def _set_constants(
+        self,
+        L: Optional[np.ndarray] = None,
+        known_dimensions: Optional[KnownDimensions] = None,
+    ) -> None:
+        if L is None and known_dimensions is None:
+            raise ValueError(
+                "You must either provide a LabelMatrix or specify known_dimensions"
+            )
+        elif known_dimensions is not None:
+            self.n = known_dimensions.num_examples
+            self.m = known_dimensions.num_functions
+            self.d = known_dimensions.num_events
+            self.cardinality = known_dimensions.num_classes
+        elif L is not None:
+            # We know L is not none, but the linter can't figure it out ...
+            self.n, self.m = L.shape
+        else:
+            raise ValueError(
+                "Something impossible happened. This is here for the sake of the linter"
+            )
         if self.m < 3:
             raise ValueError("L_train should have at least 3 labeling functions")
         self.t = 1

     def _create_tree(self) -> None:
         nodes = range(self.m)
         self.c_tree = get_clique_tree(nodes, [])
+        self._calculate_clique_data()

     def _execute_logging(self, loss: torch.Tensor) -> Metrics:
         self.eval()
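The two ways `_set_constants` can now be driven, as a sketch (the sizes are made up, and the `KnownDimensions` fields follow the earlier sketch):

```python
import numpy as np

from snorkel.labeling.model import LabelModel
from snorkel.labeling.model.sparse_label_model.sparse_label_model_helpers import (
    KnownDimensions,
)

label_model = LabelModel(cardinality=3)

# Dense mode: n and m are inferred from the shifted label matrix itself.
L_shift = np.random.randint(0, 4, size=(1000, 5))  # toy data in {0, ..., k}
label_model._set_constants(L=L_shift)

# Sparse mode: no materialized matrix, so the caller states the
# dimensions up front.
dims = KnownDimensions(num_examples=1000, num_functions=5, num_classes=3)
label_model._set_constants(known_dimensions=dims)
```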
@@ -607,7 +641,6 @@ def _execute_logging(self, loss: torch.Tensor) -> Metrics:
         # Reset running loss and examples counts
         self.running_loss = 0.0
         self.running_examples = 0
-
         self.train()
         return metrics_dict
@@ -861,13 +894,7 @@ def fit(
         >>> label_model.fit(L, class_balance=[0.7, 0.3], n_epochs=200, l2=0.4)
         """
         # Set random seed
-        self.train_config: TrainConfig = merge_config(  # type:ignore
-            TrainConfig(), kwargs  # type:ignore
-        )
-        # Update base config so that it includes all parameters
-        random.seed(self.train_config.seed)
-        np.random.seed(self.train_config.seed)
-        torch.manual_seed(self.train_config.seed)
+        self._set_config_and_seed(**kwargs)

         L_shift = L_train + 1  # convert to {0, 1, ..., k}
         if L_shift.max() > self.cardinality:
@@ -876,15 +903,37 @@
             )

-        self._set_constants(L_shift)
-        self._set_class_balance(class_balance, Y_dev)
-        self._create_tree()
+        self._training_preamble(class_balance=class_balance, Y_dev=Y_dev, **kwargs)
         lf_analysis = LFAnalysis(L_train)
         self.coverage = lf_analysis.lf_coverages()

         # Compute O and initialize params
         if self.config.verbose:  # pragma: no cover
             logging.info("Computing O...")
         self._generate_O(L_shift)
-        self._init_params()
+        self._training_loop()
+
+    def _training_preamble(
+        self,
+        Y_dev: Optional[np.ndarray] = None,
+        class_balance: Optional[List[float]] = None,
+        **kwargs: Any
+    ) -> None:
+        """Perform the training preamble, regardless of user input."""
+        np.random.seed(self.train_config.seed)
+        torch.manual_seed(self.train_config.seed)
+        self._set_class_balance(class_balance, Y_dev)
+        self._create_tree()
+
+    def _set_config_and_seed(self, **kwargs: Any) -> None:
+        self.train_config: TrainConfig = merge_config(  # type:ignore
+            TrainConfig(), kwargs  # type:ignore
+        )
+        # Update base config so that it includes all parameters
+        random.seed(self.train_config.seed)
+
+    def _training_loop(self) -> None:
+        """Perform training logic that is shared across different fit methods, irrespective of the user input format."""
+        self._init_params()

         # Estimate \mu
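The point of splitting `fit` into `_set_config_and_seed`, `_training_preamble`, and `_training_loop` is that a sparse entry point can reuse the shared logic without ever materializing a dense `L`. One way such a method could compose the pieces (a sketch; `SparseExampleLabelModel` and `fit_from_known_dimensions` are hypothetical names, not part of this diff):

```python
from typing import Any

import numpy as np

from snorkel.labeling.model import LabelModel
from snorkel.labeling.model.sparse_label_model.sparse_label_model_helpers import (
    KnownDimensions,
)


class SparseExampleLabelModel(LabelModel):
    """Hypothetical subclass showing how the extracted pieces compose."""

    def fit_from_known_dimensions(
        self,
        L_aug: np.ndarray,
        known_dimensions: KnownDimensions,
        **kwargs: Any,
    ) -> None:
        # Merge the train config and seed random, as fit() does first.
        self._set_config_and_seed(**kwargs)
        # Dimensions come from the caller instead of L.shape.
        self._set_constants(known_dimensions=known_dimensions)
        # Seeds np/torch, sets class balance, and builds the clique tree
        # (which needs self.m from the line above).
        self._training_preamble(**kwargs)
        # O is built straight from the caller-supplied augmented matrix.
        self._generate_O_from_L_aug(L_aug)
        # Parameter init plus the shared optimization loop.
        self._training_loop()
```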
Review comment: It's cool to add this to .gitignore, but then let's not add all the files in the .idea directory. Ditto with workspace.code-workspace: let's not add it to the repo, but if you want to add that type of file to .gitignore, that's fine.