Skip to content

Commit

Permalink
Fixed Trainer Initialization Parameter Order
Browse files Browse the repository at this point in the history
  • Loading branch information
iSiddharth20 committed Dec 27, 2023
1 parent 2a99816 commit abc695b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main():
# Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM
os.makedirs('../Models/Method1', exist_ok=True) # Creating Directory for Model Saving
model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth'
trainer_autoencoder_baseline = Trainer(model=model_autoencoder, loss_function=loss_mse, optimizer=torch.optim.Adam(model_autoencoder.parameters()), model_save_path=model_save_path_ae)
trainer_autoencoder_baseline = Trainer(model=model_autoencoder, loss_function=loss_mse, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
print('Baseline AutoEncoder Trainer Initialized.')
model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth'
lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)
Expand All @@ -78,7 +78,7 @@ def main():
# Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM
os.makedirs('../Models/Method2', exist_ok=True) # Creating Directory for Model Saving
model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth'
trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, optimizer=torch.optim.Adam(model_autoencoder.parameters()), model_save_path=model_save_path_ae)
trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
print('Method-2 AutoEncoder Trainer Initialized.')
print('Method-2 LSTM == Method-1 LSTM')

Expand Down

0 comments on commit abc695b

Please sign in to comment.