Loading model #168
base: Develop_copy
Conversation
I was able to reproduce the issue from #153 and its solution 👍🏻
naslib/defaults/darts_defaults.yaml
Outdated
 batch_size: 64
 learning_rate: 0.025
 learning_rate_min: 0.001
 momentum: 0.9
 weight_decay: 0.0003
-epochs: 50
+epochs: 5
Was this change only made for testing? The paper also states 50 epochs.
Yes, please revert to 50.
I reverted epochs to its original value.
@@ -133,6 +133,12 @@ def new_epoch(self, epoch):
         """
         Just log the architecture weights.
         """
+        # print("=====================================")
Why was this code added? Can it be removed?
The extra code, which was used for debugging, has been removed.
 fidelity: 200

 # GDAS
To make the YAML files generally more readable, should they focus only on optimizer-specific settings, @Neonkraft?
Agreed.
darts_defaults.yaml was reverted to the format of the Develop_copy branch.
naslib/defaults/trainer.py
Outdated
@@ -146,7 +147,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

                 self.train_loss.update(float(train_loss.detach().cpu()))
                 self.val_loss.update(float(val_loss.detach().cpu()))

+                # break
Can this be removed?
Agreed. Please remove.
The 'break' was used for debugging and has been removed in the new commit.
+    def set_checkpointables(self, architectural_weights):
+        """
+        would set the objects saved in the checkpoint during last phase of training
Since other functions also include this, maybe add a description of parameters and return types.
+1, agreed.
The types of the args have been specified. This function has no return value.
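For illustration, a documented version could look like the following minimal sketch. Only the signature appears in the diff above; the class name, the parameter type, and the method body are assumptions:

```python
class Trainer:
    def set_checkpointables(self, architectural_weights):
        """
        Set the objects saved in the checkpoint during the last phase of training.

        Args:
            architectural_weights: The architecture parameters to restore,
                e.g. a torch.nn.ParameterList (the exact type is an assumption).

        Returns:
            None
        """
        # Hypothetical body: keep the weights so they can be checkpointed later.
        self.architectural_weights = architectural_weights
```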
naslib/runners/nas/runner.py
Outdated
 trainer.search(resume_from="")
 trainer.evaluate(resume_from="", dataset_api=dataset_api)

+# trainer.search(resume_from="/home/moradias/nas-fix/run/nasbench201/cifar10/darts/97/search/model_0000002.pth")
Since this is most likely for testing purposes, it should be removed.
naslib/runners/nas/runner.py
Outdated
 'transbench101_macro': TransBench101SearchSpaceMacro(),
 'asr': NasBenchASRSearchSpace(),
 'nasbench201': NasBench201SearchSpace(n_classes=config.n_classes),
 # 'nasbench301': NasBench301SearchSpace(n_classes=config.n_classes, auxiliary=False),
Should we uncomment NB301 here, so it can be used?
Yes. Also, why remove transbench101_macro and asr? @shakibamrd
Some parts of the runner were removed during debugging. The new commit reverts it to its original form.
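For reference, a sketch of the restored registry based only on the entries visible in the diff above. The dict name is an assumption, and the search-space classes and config are assumed to be imported/in scope as in the runner:

```python
# Sketch of the restored search-space registry in naslib/runners/nas/runner.py.
# Only the four entries below appear in the diff; anything else is an assumption.
supported_search_spaces = {
    'nasbench201': NasBench201SearchSpace(n_classes=config.n_classes),
    'nasbench301': NasBench301SearchSpace(n_classes=config.n_classes, auxiliary=False),
    'transbench101_macro': TransBench101SearchSpaceMacro(),
    'asr': NasBenchASRSearchSpace(),
}
```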
@@ -6,7 +6,7 @@
 from naslib.search_spaces.core.graph import Graph, EdgeData
 from naslib.search_spaces.core import primitives as ops

-from ..nasbench301.graph import _truncate_input_edges
+# from ..nasbench301.graph import _truncate_input_edges
Can the unused import be removed?
Why does this fix change stuff in the graph of Simple Cell?
_truncate_input_edges is not called or used in SimpleCell, which is why I removed it. I have put it back in the new commit.
 elif config.dataset == 'ImageNet16-120':
     config.n_classes = 120
 else:
     config.n_classes = 10
Could this conflict with e.g. custom datasets? We could raise an exception or create a warning instead.

Suggested change:
-    config.n_classes = 10
+    raise AttributeError

Suggested change:
-    config.n_classes = 10
+    import warnings
+    warnings.warn("Number of classes was not set. Default 10 is set.")
Thanks for the suggestions; I added the warning.
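Applied to the snippet above, the result could look like this runnable sketch; the SimpleNamespace stand-in for the real config object, and any branches preceding the one shown in the diff, are assumptions for illustration:

```python
import warnings
from types import SimpleNamespace

config = SimpleNamespace(dataset="custom")  # stand-in for the real config object

if config.dataset == "ImageNet16-120":
    config.n_classes = 120
else:
    # Warn instead of silently assuming 10 classes, per the review suggestion.
    config.n_classes = 10
    warnings.warn("Number of classes was not set. Default 10 is set.")
```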
@@ -138,6 +138,7 @@ def single_evaluate(self, test_data, zc_api):
         logger.info("Querying the predictor")
         query_time_start = time.time()

+        # TODO: shouldn't mode="val" be passed?
Makes sense to me.
Saves the return values of optimizer.get_checkpointables() in the checkpoint, and is also able to load them.

Changed:
- trainer.py
- most of optimizer.py

Please review the use of _set_checkpoint() in trainer.py, which happens before calling optimizer.before_training(). In the metaclass of the optimizers it is mentioned that _set_checkpoint() should be used c
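As a rough illustration of the described flow, not the actual implementation: the method names get_checkpointables(), _set_checkpoint(), and before_training() come from the description above, while the checkpoint layout, the argument passed to _set_checkpoint(), and the torch.save/torch.load usage are assumptions:

```python
import torch

def save_optimizer_state(optimizer, path):
    # Persist whatever the optimizer reports as checkpointable.
    torch.save(optimizer.get_checkpointables(), path)

def resume_training(trainer, optimizer, path):
    checkpointables = torch.load(path)
    # Per the description, _set_checkpoint() runs before before_training();
    # passing the loaded dict here is a hypothetical signature.
    trainer._set_checkpoint(checkpointables)
    optimizer.before_training()
```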