diff --git a/scripts/train_cl.py b/scripts/train_cl.py index 25cd497..2efbf14 100644 --- a/scripts/train_cl.py +++ b/scripts/train_cl.py @@ -171,7 +171,7 @@ def main_process(rank: int, world_size: int, args): else: train_epoch(args.activate_wandb, args.model_config.epochs, epoch, pre_train_dataloader, model, optimizer, criterion, device, open_clip_ver=open_clip_ver, rank=rank) - if epoch % args.model_config.evaluation_period == 0 or epoch == args.model_config.epochs - 1 and rank == 0: + if (epoch % args.model_config.evaluation_period == 0 or epoch == args.model_config.epochs - 1) and rank == 0 and epoch != 0: if hasattr(args.model_config, 'dataset') and args.model_config.dataset == "INSECT": acc_dict, pred_dict = eval_phase(model, device, insect_train_dataloader_for_key, insect_val_dataloader, insect_test_seen_dataloader, insect_test_unseen_dataloader, k_list)