
Commit

Debug Survtrace (#44)
* add optimizer

* solve survtrace convergence problem
Vincent-Maladiere authored Jan 24, 2024
1 parent 3aa776b commit bc57614
Showing 2 changed files with 31 additions and 44 deletions.
15 changes: 12 additions & 3 deletions hazardous/survtrace/_bert_layers.py
@@ -21,12 +21,14 @@ def __init__(
hidden_size=16,
layer_norm_eps=1e-12,
hidden_dropout_prob=0.0,
initializer_range=0.02,
):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size + 1, hidden_size)
self.num_embeddings = nn.Parameter(
torch.randn(1, n_numerical_features, hidden_size)
)
self.num_embeddings.data.normal_(mean=0.0, std=initializer_range)
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.dropout = nn.Dropout(hidden_dropout_prob)
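The new `initializer_range` argument re-initializes the numerical-feature embedding weights with a small-variance normal distribution, the usual BERT-style initializer, instead of keeping the unit-variance draw from `torch.randn`. A minimal standalone sketch of that pattern (the feature counts below are illustrative, not taken from the diff):

```python
import torch
from torch import nn

# Illustrative sizes; only the init pattern mirrors the diff above.
initializer_range = 0.02
n_numerical_features, hidden_size = 4, 16

num_embeddings = nn.Parameter(torch.randn(1, n_numerical_features, hidden_size))
num_embeddings.data.normal_(mean=0.0, std=initializer_range)  # in-place re-init

print(num_embeddings.std())  # ~0.02 rather than ~1.0
```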

@@ -89,6 +91,7 @@ class BertLayer(nn.Module):
def __init__(self):
super().__init__()
self.seq_len_dim = 1
self.chunk_size_feed_forward = 0
self.attention = BertAttention()
self.intermediate = BertIntermediate()
self.output = BertOutput()
@@ -119,7 +122,10 @@ def forward(
] # add self attentions if we output attention weights

hidden_states = apply_chunking_to_forward(
self.feed_forward_chunk, self.seq_len_dim, attention_output
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)

return hidden_states, self_attentions
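The call now spells out all four arguments in the order expected by `apply_chunking_to_forward` as redefined later in this file: `(forward_fn, chunk_size, chunk_dim, *input_tensors)`. A hedged usage sketch, assuming the `hazardous` package is importable; `feed_forward` and the tensor sizes are toy stand-ins for `self.feed_forward_chunk` and the real activations:

```python
import torch

# Toy usage of the corrected argument order.
from hazardous.survtrace._bert_layers import apply_chunking_to_forward


def feed_forward(attention_output):
    # Stand-in for self.feed_forward_chunk; not part of the diff.
    return attention_output * 2


attention_output = torch.randn(8, 5, 16)  # (batch, seq_len, hidden_size)
hidden_states = apply_chunking_to_forward(
    feed_forward,
    0,  # chunk_size_feed_forward (unused by this simplified implementation)
    1,  # seq_len_dim
    attention_output,
)
print(hidden_states.shape)  # torch.Size([8, 5, 16])
```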
@@ -510,7 +516,10 @@ def find_pruneable_heads_and_indices(


def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor], chunk_dim: int, *input_tensors
forward_fn: Callable[..., torch.Tensor],
chunk_size: int,
chunk_dim: int,
*input_tensors,
) -> torch.Tensor:
"""
This function chunks the :obj:`input_tensors` into smaller input tensor \
@@ -547,7 +556,7 @@ def forward(self, hidden_states):
return apply_chunking_to_forward(self.forward_chunk, \
self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
"""

del chunk_size
assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
tensor_shape = input_tensors[0].shape[chunk_dim]
assert all(
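For reference, a hedged sketch of what the simplified helper now amounts to. `chunk_size` is accepted only to keep the Hugging Face-style signature and is discarded; the diff cuts off before the function's return, so the final `forward_fn(*input_tensors)` call and the second assert message are assumptions based on the surrounding code, not lines shown in the commit:

```python
from typing import Callable

import torch


def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:
    # chunk_size is kept for signature compatibility with the upstream
    # transformers helper but is never used: no chunking is performed here.
    del chunk_size
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
    tensor_shape = input_tensors[0].shape[chunk_dim]
    assert all(
        t.shape[chunk_dim] == tensor_shape for t in input_tensors
    ), "All input tensors must share the same size along chunk_dim"
    return forward_fn(*input_tensors)
```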
60 changes: 19 additions & 41 deletions hazardous/survtrace/_model.py
@@ -11,6 +11,7 @@
from skorch import NeuralNet
from skorch.callbacks import Callback, ProgressBar
from skorch.dataset import ValidSplit, unpack_data
from torch.optim import Adam

from hazardous.utils import get_n_events

@@ -22,7 +23,6 @@
)
from ._encoder import SurvFeatureEncoder, SurvTargetEncoder
from ._losses import NLLPCHazardLoss
from ._optimizer import BERTAdam
from ._utils import pad_col_3d


@@ -77,13 +77,13 @@ def __init__(

criterion = NLLPCHazardLoss
callbacks = [ShapeSetter(), ProgressBar(detect_notebook=False)]
optimizer = BERTAdam # Adam
optimizer = Adam
super().__init__(
module=module,
criterion=criterion,
optimizer=optimizer,
optimizer__lr=lr,
optimizer__weight_decay_rate=weight_decay,
optimizer__weight_decay=weight_decay,
callbacks=callbacks,
batch_size=batch_size,
device=device,
@@ -106,38 +106,6 @@ def initialize_module(self):
setattr(self.module, module_name, sub_module)
return super().initialize_module()

def initialize_optimizer(self, triggered_directly=None):
"""Initialize the model optimizer. If ``self.optimizer__lr``
is not set, use ``self.lr`` instead.
"""
named_parameters = self.get_all_learnable_params()
_, kwargs = self.get_params_for_optimizer("optimizer", named_parameters)

# assign no weight decay on these parameters
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
param_optimizer = list(self.module.named_parameters())
optimizer_grouped_parameters = [
{
"params": [
p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
],
"weight_decay": 0,
},
{
"params": [
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]

self.optimizer_ = BERTAdam(
optimizer_grouped_parameters,
**kwargs,
)

return self

def check_data(self, X, y=None):
if not hasattr(X, "__dataframe__"):
raise TypeError(f"X must be a dataframe. Got {type(X)}")
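With `BERTAdam` and the hand-rolled `initialize_optimizer` removed, the estimator relies on skorch's standard parameter routing: `optimizer__lr` and `optimizer__weight_decay` are forwarded to `torch.optim.Adam`. A minimal sketch of that mechanism on a toy regression net (`TinyNet` and the synthetic data are illustrative, not part of the repository):

```python
import torch
from skorch import NeuralNetRegressor
from torch import nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, X):
        return self.linear(X)


# skorch strips the "optimizer__" prefix and passes the remaining keyword
# arguments to the optimizer class, here torch.optim.Adam.
net = NeuralNetRegressor(
    TinyNet,
    optimizer=torch.optim.Adam,
    optimizer__lr=1e-3,
    optimizer__weight_decay=0.0,
    max_epochs=2,
    train_split=None,
    verbose=0,
)

X = torch.randn(64, 4)
y = X.sum(dim=1, keepdim=True)
net.fit(X, y)
print(type(net.optimizer_).__name__)  # Adam
```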
@@ -224,16 +192,20 @@ def train_step_single(self, batch, **fit_params):
Xi, yi = unpack_data(batch)
loss = 0
all_y_pred = []
# import ipdb; ipdb.set_trace()
event_multiclass = yi["event"].copy()
for event_of_interest in range(1, self.n_events + 1):
yi["event"] = (event_multiclass == event_of_interest).long()
fit_params["event_of_interest"] = event_of_interest
y_pred = self.infer(Xi, **fit_params)

loss += self.get_loss(y_pred, yi, X=Xi, training=True)
all_y_pred.append(y_pred[:, :, None])

all_y_pred = torch.concatenate(
all_y_pred, axis=2
) # (n_samples, n_time_steps, n_events)
loss.backward()

return {
"loss": loss,
"y_pred": all_y_pred,
@@ -260,12 +232,16 @@ def validation_step(self, batch, **fit_params):
Xi, yi = unpack_data(batch)
loss = 0
all_y_pred = []
event_multiclass = yi["event"].copy()
with torch.no_grad():
for event_of_interest in range(1, self.n_events + 1):
yi["event"] = (event_multiclass == event_of_interest).long()
fit_params["event_of_interest"] = event_of_interest
y_pred = self.infer(Xi, **fit_params)

loss += self.get_loss(y_pred, yi, X=Xi, training=False)
all_y_pred.append(y_pred[:, :, None])

all_y_pred = torch.concatenate(
all_y_pred, axis=2
) # (n_samples, n_time_steps, n_events)
@@ -358,18 +334,19 @@ def __init__(
self,
init_range=0.02,
# BertEmbedding
n_numerical_features=1,
vocab_size=8,
n_numerical_features=1, # *
vocab_size=8, # *
hidden_size=16,
layer_norm_eps=1e-12,
hidden_dropout_prob=0.0,
initializer_range=0.02,
# BertEncoder
num_hidden_layers=3,
# BertCLS
intermediate_size=64,
n_events=1,
n_features_in=1,
n_features_out=1,
n_events=1, # *
n_features_in=1, # *
n_features_out=1, # *
):
super().__init__()
self.init_range = init_range
@@ -379,6 +356,7 @@ def __init__(
hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
initializer_range=initializer_range,
)
self.encoder = BertEncoder(num_hidden_layers=num_hidden_layers)
self.cls = BertCLSMulti(
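Stepping back from the skorch plumbing, both `train_step_single` and `validation_step` now follow the same one-versus-rest pattern: keep the multiclass event labels aside, binarize them once per competing event, accumulate the per-event losses, and stack the predictions along a trailing event axis. A framework-free sketch of that loop; the random predictions and the placeholder loss stand in for `self.infer` and `NLLPCHazardLoss`:

```python
import torch

n_samples, n_time_steps, n_events = 32, 10, 2

# 0 encodes censoring, 1..n_events encode the competing events.
event_multiclass = torch.randint(0, n_events + 1, (n_samples,))
fake_predictions = torch.randn(n_samples, n_time_steps, n_events)

loss = torch.tensor(0.0)
all_y_pred = []
for event_of_interest in range(1, n_events + 1):
    # Binarize the labels for the current event of interest.
    y_event = (event_multiclass == event_of_interest).long()
    # Stand-in for y_pred = self.infer(Xi, **fit_params).
    y_pred = fake_predictions[:, :, event_of_interest - 1]
    # Placeholder for self.get_loss(y_pred, yi, ...), accumulated over events.
    loss = loss + y_pred.mean() * y_event.float().mean()
    all_y_pred.append(y_pred[:, :, None])

all_y_pred = torch.concatenate(all_y_pred, axis=2)
print(all_y_pred.shape)  # (n_samples, n_time_steps, n_events) -> torch.Size([32, 10, 2])
```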
