From bc576143743ef168f5e709ef9629a9d8459bda4b Mon Sep 17 00:00:00 2001
From: Vincent M
Date: Wed, 24 Jan 2024 19:19:13 +0100
Subject: [PATCH] Debug Survtrace (#44)

* add optimizer

* solve survtrace convergence problem
---
 hazardous/survtrace/_bert_layers.py | 15 ++++++--
 hazardous/survtrace/_model.py       | 60 +++++++++--------------------
 2 files changed, 31 insertions(+), 44 deletions(-)

diff --git a/hazardous/survtrace/_bert_layers.py b/hazardous/survtrace/_bert_layers.py
index 8f1709c..876a794 100644
--- a/hazardous/survtrace/_bert_layers.py
+++ b/hazardous/survtrace/_bert_layers.py
@@ -21,12 +21,14 @@ def __init__(
         hidden_size=16,
         layer_norm_eps=1e-12,
         hidden_dropout_prob=0.0,
+        initializer_range=0.02,
     ):
         super().__init__()
         self.word_embeddings = nn.Embedding(vocab_size + 1, hidden_size)
         self.num_embeddings = nn.Parameter(
             torch.randn(1, n_numerical_features, hidden_size)
         )
+        self.num_embeddings.data.normal_(mean=0.0, std=initializer_range)
         self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
         self.dropout = nn.Dropout(hidden_dropout_prob)
 
@@ -89,6 +91,7 @@ class BertLayer(nn.Module):
     def __init__(self):
         super().__init__()
         self.seq_len_dim = 1
+        self.chunk_size_feed_forward = 0
         self.attention = BertAttention()
         self.intermediate = BertIntermediate()
         self.output = BertOutput()
@@ -119,7 +122,10 @@ def forward(
         ]  # add self attentions if we output attention weights
 
         hidden_states = apply_chunking_to_forward(
-            self.feed_forward_chunk, self.seq_len_dim, attention_output
+            self.feed_forward_chunk,
+            self.chunk_size_feed_forward,
+            self.seq_len_dim,
+            attention_output,
         )
 
         return hidden_states, self_attentions
@@ -510,7 +516,10 @@ def find_pruneable_heads_and_indices(
 
 
 def apply_chunking_to_forward(
-    forward_fn: Callable[..., torch.Tensor], chunk_dim: int, *input_tensors
+    forward_fn: Callable[..., torch.Tensor],
+    chunk_size: int,
+    chunk_dim: int,
+    *input_tensors,
 ) -> torch.Tensor:
     """
     This function chunks the :obj:`input_tensors` into smaller input tensor \
@@ -547,7 +556,7 @@ def forward(self, hidden_states):
         return apply_chunking_to_forward(self.forward_chunk, \
             self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
     """
-
+    del chunk_size
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
     tensor_shape = input_tensors[0].shape[chunk_dim]
     assert all(
diff --git a/hazardous/survtrace/_model.py b/hazardous/survtrace/_model.py
index 50ad2ee..46ddc01 100644
--- a/hazardous/survtrace/_model.py
+++ b/hazardous/survtrace/_model.py
@@ -11,6 +11,7 @@
 from skorch import NeuralNet
 from skorch.callbacks import Callback, ProgressBar
 from skorch.dataset import ValidSplit, unpack_data
+from torch.optim import Adam
 
 from hazardous.utils import get_n_events
 
@@ -22,7 +23,6 @@
 )
 from ._encoder import SurvFeatureEncoder, SurvTargetEncoder
 from ._losses import NLLPCHazardLoss
-from ._optimizer import BERTAdam
 from ._utils import pad_col_3d
 
 
@@ -77,13 +77,13 @@ def __init__(
 
         criterion = NLLPCHazardLoss
         callbacks = [ShapeSetter(), ProgressBar(detect_notebook=False)]
-        optimizer = BERTAdam  # Adam
+        optimizer = Adam
         super().__init__(
             module=module,
             criterion=criterion,
             optimizer=optimizer,
             optimizer__lr=lr,
-            optimizer__weight_decay_rate=weight_decay,
+            optimizer__weight_decay=weight_decay,
             callbacks=callbacks,
             batch_size=batch_size,
             device=device,
@@ -106,38 +106,6 @@ def initialize_module(self):
             setattr(self.module, module_name, sub_module)
         return super().initialize_module()
 
-    def initialize_optimizer(self, triggered_directly=None):
-        """Initialize the model optimizer. If ``self.optimizer__lr``
-        is not set, use ``self.lr`` instead.
-        """
-        named_parameters = self.get_all_learnable_params()
-        _, kwargs = self.get_params_for_optimizer("optimizer", named_parameters)
-
-        # assign no weight decay on these parameters
-        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
-        param_optimizer = list(self.module.named_parameters())
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": 0,
-            },
-            {
-                "params": [
-                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
-
-        self.optimizer_ = BERTAdam(
-            optimizer_grouped_parameters,
-            **kwargs,
-        )
-
-        return self
-
     def check_data(self, X, y=None):
         if not hasattr(X, "__dataframe__"):
             raise TypeError(f"X must be a dataframe. Got {type(X)}")
@@ -224,16 +192,20 @@ def train_step_single(self, batch, **fit_params):
         Xi, yi = unpack_data(batch)
         loss = 0
         all_y_pred = []
-        # import ipdb; ipdb.set_trace()
+        event_multiclass = yi["event"].copy()
         for event_of_interest in range(1, self.n_events + 1):
+            yi["event"] = (event_multiclass == event_of_interest).long()
             fit_params["event_of_interest"] = event_of_interest
             y_pred = self.infer(Xi, **fit_params)
+
             loss += self.get_loss(y_pred, yi, X=Xi, training=True)
             all_y_pred.append(y_pred[:, :, None])
+
         all_y_pred = torch.concatenate(
             all_y_pred, axis=2
         )  # (n_samples, n_time_steps, n_events)
         loss.backward()
+
         return {
             "loss": loss,
             "y_pred": all_y_pred,
@@ -260,12 +232,16 @@ def validation_step(self, batch, **fit_params):
         Xi, yi = unpack_data(batch)
         loss = 0
         all_y_pred = []
+        event_multiclass = yi["event"].copy()
         with torch.no_grad():
             for event_of_interest in range(1, self.n_events + 1):
+                yi["event"] = (event_multiclass == event_of_interest).long()
                 fit_params["event_of_interest"] = event_of_interest
                 y_pred = self.infer(Xi, **fit_params)
+
                 loss += self.get_loss(y_pred, yi, X=Xi, training=False)
                 all_y_pred.append(y_pred[:, :, None])
+
         all_y_pred = torch.concatenate(
             all_y_pred, axis=2
         )  # (n_samples, n_time_steps, n_events)
@@ -358,18 +334,19 @@ def __init__(
         self,
         init_range=0.02,
         # BertEmbedding
-        n_numerical_features=1,
-        vocab_size=8,
+        n_numerical_features=1,  # *
+        vocab_size=8,  # *
         hidden_size=16,
         layer_norm_eps=1e-12,
         hidden_dropout_prob=0.0,
+        initializer_range=0.02,
         # BertEncoder
         num_hidden_layers=3,
         # BertCLS
         intermediate_size=64,
-        n_events=1,
-        n_features_in=1,
-        n_features_out=1,
+        n_events=1,  # *
+        n_features_in=1,  # *
+        n_features_out=1,  # *
     ):
         super().__init__()
         self.init_range = init_range
@@ -379,6 +356,7 @@
             hidden_size=hidden_size,
             layer_norm_eps=layer_norm_eps,
             hidden_dropout_prob=hidden_dropout_prob,
+            initializer_range=initializer_range,
         )
         self.encoder = BertEncoder(num_hidden_layers=num_hidden_layers)
         self.cls = BertCLSMulti(
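
Note (illustration, not part of the patch): the optimizer change above relies on skorch
routing every ``optimizer__*`` argument to the optimizer constructor, which is why
``optimizer__weight_decay_rate`` (the BERTAdam keyword) becomes ``optimizer__weight_decay``
(the ``torch.optim.Adam`` keyword) once the custom ``initialize_optimizer`` override is
dropped. A minimal, self-contained sketch of that mechanism, using a placeholder module
and random data rather than SurvTRACE code:

    import numpy as np
    from skorch import NeuralNetClassifier
    from torch import nn
    from torch.optim import Adam


    class TinyModule(nn.Module):
        # Placeholder network, only to show how skorch forwards optimizer__* kwargs.
        def __init__(self, n_features=4, n_classes=2):
            super().__init__()
            self.linear = nn.Linear(n_features, n_classes)

        def forward(self, X):
            return self.linear(X)


    net = NeuralNetClassifier(
        TinyModule,
        criterion=nn.CrossEntropyLoss,
        optimizer=Adam,
        optimizer__lr=1e-3,           # forwarded as Adam(lr=1e-3, ...)
        optimizer__weight_decay=0.0,  # forwarded as Adam(weight_decay=0.0, ...)
        max_epochs=2,
        train_split=None,
    )

    X = np.random.rand(32, 4).astype("float32")
    y = np.random.randint(0, 2, size=32).astype("int64")
    net.fit(X, y)

For reference, the removed ``initialize_optimizer`` override grouped parameters so that
biases and LayerNorm weights could be excluded from weight decay, but it set
``weight_decay`` to zero for both groups anyway; plain ``Adam`` with a single
``weight_decay`` value therefore reproduces that behaviour, and parameter groups would
only need to be reintroduced if a non-zero decay should again skip those parameters.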