
Loading model #168 (Open)
wants to merge 13 commits into base: Develop_copy
Conversation

shakibamrd (Collaborator):
Saves the return values of optimizer.get_checkpointables() in the checkpoint, and is also able to load them.
Changed:
trainer.py
most of optimizer.py

Please review the use of _set_checkpoint() in trainer.py, which is called before optimizer.before_training(). In the metaclass of the optimizers it is mentioned that _set_checkpoint() should be used c
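For reviewers, a minimal sketch of the save-and-resume flow this PR describes, assuming torch-serialized checkpoints. Apart from get_checkpointables(), set_checkpointables() and before_training(), every name below is a hypothetical illustration, not the PR's actual code:

import torch

def save_checkpoint(optimizer, path):
    # Collect whatever the optimizer exposes as checkpointable state
    # (e.g. the architectural weights) and serialize it in one file.
    torch.save(optimizer.get_checkpointables(), path)

def resume(optimizer, path):
    # Restore the saved objects, then run the pre-training hook; whether
    # the restore belongs before or after before_training() is exactly
    # the ordering question raised above.
    checkpointables = torch.load(path)
    optimizer.set_checkpointables(checkpointables)
    optimizer.before_training()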

@gierle (Contributor) left a comment:
I was able to reproduce the issue of #153 and its solution 👍🏻

batch_size: 64
learning_rate: 0.025
learning_rate_min: 0.001
momentum: 0.9
weight_decay: 0.0003
-epochs: 50
+epochs: 5
Contributor:

Was this change only used for testing? 50 epochs is also the value stated in the paper.

Collaborator:

Yes, please revert to 50.

Collaborator (Author):

I reverted the epoch count to its original value.

@@ -133,6 +133,12 @@ def new_epoch(self, epoch):
"""
Just log the architecture weights.
"""
# print("=====================================")
Contributor:

Why was this code added? Can it be removed?

Collaborator (Author):

The extra debugging code has been removed.


fidelity: 200

# GDAS
Contributor:

To make the yaml files generally more readable, should they focus only on optimizer-specific settings, @Neonkraft?

Collaborator:

Agreed.

Collaborator (Author):

The darts_defaults.yaml was reverted to the format used on the Develop_copy branch.

@@ -146,7 +147,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

self.train_loss.update(float(train_loss.detach().cpu()))
self.val_loss.update(float(val_loss.detach().cpu()))

# break
Contributor:

Can this be removed?

Collaborator:

Agreed. Please remove.

Collaborator (Author):

The 'break' was used for debugging and has been removed in the new commit.


def set_checkpointables(self, architectural_weights):
"""
would set the objects saved in the checkpoint during last phase of training
Contributor:

Since the other functions also document this, maybe add a description of the parameters and return types here as well.

Collaborator:

+1, agreed.

Collaborator (Author):

The types of the args have been specified. This function has no return value.
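For illustration, a hedged sketch of what the documented signature could look like; the torch.nn.ParameterList type is an assumption, not necessarily the type the commit uses:

def set_checkpointables(self, architectural_weights):
    """
    Sets the objects saved in the checkpoint during the last phase of training.

    Args:
        architectural_weights (torch.nn.ParameterList): architecture weights
            restored from the checkpoint (type assumed for illustration).

    Returns:
        None
    """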

trainer.search(resume_from="")
trainer.evaluate(resume_from="", dataset_api=dataset_api)

# trainer.search(resume_from="/home/moradias/nas-fix/run/nasbench201/cifar10/darts/97/search/model_0000002.pth")
Contributor:

Since this is most likely for testing purposes, it should be removed.
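If an example call is worth keeping, a relative placeholder avoids the machine-specific path; the path below is purely illustrative:

# Illustrative only: resume from a checkpoint produced by an earlier run,
# using a relative placeholder instead of a hard-coded absolute path.
trainer.search(resume_from="run/nasbench201/cifar10/darts/97/search/model_0000002.pth")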

'transbench101_macro': TransBench101SearchSpaceMacro(),
'asr': NasBenchASRSearchSpace(),
'nasbench201': NasBench201SearchSpace(n_classes=config.n_classes),
# 'nasbench301': NasBench301SearchSpace(n_classes=config.n_classes, auxiliary=False),
Contributor:

Should we uncomment NB301 here, so it can be used?

Collaborator:

Yes. Also, why remove transbench101_macro and asr? @shakibamrd

Collaborator (Author):

The runner had some parts removed during debugging. With the new commit it has been reverted to its original form.
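For reference, a hedged sketch of the reverted registry with NB301 uncommented, as discussed above; the variable name is an assumption, and only the entries visible in the diff are shown:

# Hypothetical shape of the reverted mapping; constructor arguments are
# taken from the snippet above, and other entries are elided.
supported_search_spaces = {
    'transbench101_macro': TransBench101SearchSpaceMacro(),
    'asr': NasBenchASRSearchSpace(),
    'nasbench201': NasBench201SearchSpace(n_classes=config.n_classes),
    'nasbench301': NasBench301SearchSpace(n_classes=config.n_classes, auxiliary=False),
}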

@@ -6,7 +6,7 @@
from naslib.search_spaces.core.graph import Graph, EdgeData
from naslib.search_spaces.core import primitives as ops

-from ..nasbench301.graph import _truncate_input_edges
+# from ..nasbench301.graph import _truncate_input_edges
Contributor:

Can the unused import be removed?

Collaborator:

Why does this fix change things in the SimpleCell graph?

Collaborator (Author):

_truncate_input_edges is not called or used in SimpleCell, which is why I removed it. I have put it back in the new commit.

elif config.dataset == 'ImageNet16-120':
config.n_classes = 120
else:
config.n_classes = 10
Contributor:

Could this conflict with e.g. custom datasets? We could raise an exception or emit a warning instead.

Suggested change:
-config.n_classes = 10
+raise AttributeError

Suggested change:
-config.n_classes = 10
+import warnings
+warnings.warn("Number of classes was not set. Default 10 is set.")

Collaborator (Author):

Thanks for the suggestions; I added the warning.
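For clarity, a minimal sketch of the adopted fallback; branches for datasets other than ImageNet16-120 are elided, and the surrounding config object is assumed from the snippet above:

import warnings

# Sketch of the fallback with the suggested warning applied; other
# dataset branches are elided.
if config.dataset == 'ImageNet16-120':
    config.n_classes = 120
else:
    # Warn instead of silently defaulting, so custom datasets surface it.
    warnings.warn("Number of classes was not set. Default 10 is set.")
    config.n_classes = 10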

@@ -138,6 +138,7 @@ def single_evaluate(self, test_data, zc_api):
logger.info("Querying the predictor")
query_time_start = time.time()

+# TODO: shouldn't mode="val" be passed?
Contributor:

Makes sense to me.
