
Bm/load_from_checkpoint #116
Merged 95 commits, Jun 25, 2024
Commits
21d4246
implement test behavior with no test inputs
BenjaminMidtvedt Apr 22, 2024
e3e5089
docs
BenjaminMidtvedt Apr 29, 2024
12bd38e
Refactor code in deeplay/activelearning/data.py, deeplay/activelearni…
BenjaminMidtvedt Apr 29, 2024
2cf1cd9
clear config before multi
BenjaminMidtvedt May 2, 2024
f9dee3b
update to use absolute imports
BenjaminMidtvedt May 2, 2024
b1e18f4
Refactor import statements in test_selectors.py
BenjaminMidtvedt May 2, 2024
3c04e1e
Bm/fix-config-nested-new
BenjaminMidtvedt May 2, 2024
16758c5
remove test file
BenjaminMidtvedt May 2, 2024
e30a6f7
Fix issue with clearing configuration before creating multiple blocks
BenjaminMidtvedt May 2, 2024
6714c65
Fix issue with configuring upsample in Conv2dBlock
BenjaminMidtvedt May 2, 2024
08ac96c
Re-enable test_strided_multi in test_conv.py
BenjaminMidtvedt May 3, 2024
95cf49f
add stubs
BenjaminMidtvedt May 6, 2024
2279397
Refactor residual function in Conv2dBlock to support flexible layer o…
BenjaminMidtvedt May 6, 2024
5697e66
Refactor available_styles method in DeeplayModule to use classmethod
BenjaminMidtvedt May 6, 2024
e329512
Implement script to create stubs with style typing
BenjaminMidtvedt May 6, 2024
af7d358
Refactor Conv2dBlock and related functions in conv2d.pyi
BenjaminMidtvedt May 6, 2024
2bea2ed
Remove publish script from package.json
BenjaminMidtvedt May 6, 2024
21fb687
Add .gitignore entry for package.json
BenjaminMidtvedt May 6, 2024
139e7ef
add docs
BenjaminMidtvedt May 6, 2024
a449dd1
delete data
BenjaminMidtvedt May 6, 2024
6af3f43
Update documentation and style guide
BenjaminMidtvedt May 6, 2024
2af03bf
Update naming conventions, imports, documentation, and testing guidel…
BenjaminMidtvedt May 7, 2024
dfa89e1
Refactor Conv2dBlock and related functions in conv2d.pyi
BenjaminMidtvedt May 7, 2024
f1a03ab
Refactor Conv2dBlock and related functions in conv2d.pyi
BenjaminMidtvedt May 8, 2024
7001e42
Refactor Conv2dBlock and related functions in conv2d.pyi
BenjaminMidtvedt May 10, 2024
10a8fe8
Merge branch 'develop' into bm/docs
BenjaminMidtvedt May 15, 2024
f10cb09
§
BenjaminMidtvedt May 16, 2024
1dee8a7
u
BenjaminMidtvedt May 16, 2024
346ee21
Merge branch 'bm/docs' of https://github.com/softmatterlab/DeepTorch …
BenjaminMidtvedt May 17, 2024
28c938b
baseline shape compute
BenjaminMidtvedt May 20, 2024
bc3b95d
docsv1
BenjaminMidtvedt May 20, 2024
edb1232
chore: Update expected_input_shape property in LinearBlock class
BenjaminMidtvedt May 20, 2024
a9a310e
Refactor DeeplayModule to use torch.no_grad() for args and kwargs inv…
BenjaminMidtvedt May 20, 2024
935bc4d
chore: Refactor BaseBlock class to improve code readability and maint…
BenjaminMidtvedt May 20, 2024
983f013
Refactor LinearBlock class to improve default configuration and add m…
BenjaminMidtvedt May 20, 2024
0b2fa5c
Refactor Sequence1dBlock class to improve default normalization and a…
BenjaminMidtvedt May 20, 2024
bc307e3
Refactor Layer class to warn when forward path is called with gradien…
BenjaminMidtvedt May 20, 2024
2b2f208
Refactor resnet18.py to improve code readability and maintainability
BenjaminMidtvedt May 20, 2024
dc4a79e
Refactor Conv2dBlock to fix layer configuration and add missing asser…
BenjaminMidtvedt May 20, 2024
eefe1cc
Refactor Conv2dBlock to fix layer configuration and add missing asser…
BenjaminMidtvedt May 20, 2024
d69c9eb
Refactor Conv2dBlock to configure padding based on kernel size
BenjaminMidtvedt May 20, 2024
55da145
Refactor Conv2dBlock to configure padding based on kernel size
BenjaminMidtvedt May 20, 2024
83b86e5
Merge branch 'develop' into bm/block-attributes
BenjaminMidtvedt May 21, 2024
4e9b9b5
remove exception on RuntimeError
BenjaminMidtvedt May 22, 2024
85f2721
del docs
BenjaminMidtvedt May 22, 2024
deb0e10
Refactor DeeplayModule new method to set detach parameter to True by …
BenjaminMidtvedt May 22, 2024
6b937ba
Refactor RecurrentModel to use super() for forward pass
BenjaminMidtvedt May 22, 2024
add9942
Refactor DeeplayModule to set detach parameter to True by default
BenjaminMidtvedt May 22, 2024
0215731
Refactor DeeplayModule to return computed values in configure method
BenjaminMidtvedt May 22, 2024
469e052
Refactor Layer to undo configuration of computed values on exception
BenjaminMidtvedt May 22, 2024
63bd9f9
Refactor RNN module to include batch_first and return_cell_state para…
BenjaminMidtvedt May 23, 2024
ec8a77d
Refactor RecurrentModel to not return cell state by default
BenjaminMidtvedt May 23, 2024
f37b717
Correctly handle cell state
BenjaminMidtvedt May 23, 2024
d014c78
chore: Fix torch MPS issue with indexed tensors in DeeplayModule
BenjaminMidtvedt May 23, 2024
efe4b0d
chore: Fix torch MPS issue with indexed tensors in DeeplayModule
BenjaminMidtvedt May 23, 2024
b6fc72f
chore: Update Python and OS versions in CI workflow
BenjaminMidtvedt May 23, 2024
8b6741d
chore: Update Python and OS versions in CI workflow
BenjaminMidtvedt May 23, 2024
c652850
undo CI test
BenjaminMidtvedt May 23, 2024
fcd1806
Update module.py
giovannivolpe May 24, 2024
7e95865
Add optional normalization parameter to BaseBlock.normalized()
BenjaminMidtvedt May 25, 2024
31ba2af
Add load from pickled checkpoint in metaclass
BenjaminMidtvedt Jun 20, 2024
10063bd
Refactor Application build method to store hparams and modules
BenjaminMidtvedt Jun 20, 2024
39bfa5d
Add reduction method for pickle
BenjaminMidtvedt Jun 20, 2024
258a847
feat: Add default shortcut function to BaseBlock
BenjaminMidtvedt Jun 20, 2024
259c019
refactor: Improve ResNet style block in resnet18.py
BenjaminMidtvedt Jun 20, 2024
8fd4980
refactor: Add check for DeeplayModule in Application class
BenjaminMidtvedt Jun 20, 2024
42cade2
remove reduction method
BenjaminMidtvedt Jun 20, 2024
33f5ce3
Pickle using dill instead of pickle
BenjaminMidtvedt Jun 20, 2024
a8f0bff
chore: Update requirements.txt with dill package
BenjaminMidtvedt Jun 20, 2024
0962c3f
feat: Update requirements.txt with numpy>=1.24.1
BenjaminMidtvedt Jun 21, 2024
a834343
chore: Update numpy version in requirements.txt
BenjaminMidtvedt Jun 21, 2024
dafe7e8
add a cleanup method before create
BenjaminMidtvedt Jun 23, 2024
26847fe
add test case for double create bug
BenjaminMidtvedt Jun 23, 2024
c48a79c
remove unused import
BenjaminMidtvedt Jun 23, 2024
2f4dd36
Reduce use of deepcopy (for speed)
BenjaminMidtvedt Jun 23, 2024
56fd135
Reduce use of deepcopy (for speed)
BenjaminMidtvedt Jun 23, 2024
6fc049a
chore: Remove print statement in layer.py
BenjaminMidtvedt Jun 23, 2024
da18c7e
refactor: Add non-class type arguments to External class initialization
BenjaminMidtvedt Jun 23, 2024
18a3236
Refactor import statements in deeplay.blocks.ls.py and stylestubgen.py
BenjaminMidtvedt Jun 23, 2024
9136a13
refactor: Add stateful decorator for methods modifying object state
BenjaminMidtvedt Jun 23, 2024
2f0ce64
refactor: Add initialization steps for object attributes in meta.py
BenjaminMidtvedt Jun 23, 2024
7be1b07
refactor new method to not use deepcopy
BenjaminMidtvedt Jun 23, 2024
561c96d
refactor: Remove unused import and regex module from deeplay/module.py
BenjaminMidtvedt Jun 23, 2024
f34e043
Merge branch 'bm/load_from_checkpoint' of https://github.com/softmatt…
BenjaminMidtvedt Jun 23, 2024
4d2949a
refactor: Update ExtendedConstructorMeta to use _module_state instead…
BenjaminMidtvedt Jun 23, 2024
616f783
refactor: Remove commented out code in deeplay/module.py
BenjaminMidtvedt Jun 24, 2024
79f279c
refactor: Add typing_extensions.Self import to stylestubgen.py
BenjaminMidtvedt Jun 24, 2024
9f5b19b
Refactor test_encdec.py and test_gnn.py to use gnn.create() instead o…
BenjaminMidtvedt Jun 24, 2024
e535855
add application tests
BenjaminMidtvedt Jun 24, 2024
4e8d470
refactor: Update log_metrics method to preprocess y_hat and y before …
BenjaminMidtvedt Jun 24, 2024
daf2dd8
refactor: Remove unused imports in deeplay/tests/applications/base.py
BenjaminMidtvedt Jun 24, 2024
a004c7e
del notebook
BenjaminMidtvedt Jun 25, 2024
b80dbc6
Merge branch 'develop' into bm/load_from_checkpoint
BenjaminMidtvedt Jun 25, 2024
26c21eb
del notebook
BenjaminMidtvedt Jun 25, 2024
9e24c1b
refactor: Add _non_classtype_args attribute to External class
BenjaminMidtvedt Jun 25, 2024
300 changes: 0 additions & 300 deletions Building a block.ipynb

This file was deleted.

55 changes: 54 additions & 1 deletion deeplay/applications/application.py
@@ -1,6 +1,7 @@
 import copy

 import logging
+from pickle import PicklingError
 from typing import (
     Callable,
     Dict,
@@ -13,6 +14,7 @@
     Union,
     Any,
 )
+from warnings import warn

 import lightning as L
 import matplotlib.pyplot as plt
@@ -28,6 +30,7 @@
 import deeplay as dl
 from deeplay import DeeplayModule, Optimizer
 from deeplay.callbacks import RichProgressBar, LogHistory
+import dill

 logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
 logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)
@@ -326,8 +329,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     def log_metrics(
         self, kind: Literal["train", "val", "test"], y_hat, y, **logger_kwargs
     ):
+        ys = self.metrics_preprocess(y_hat, y)
+
         metrics: tm.MetricCollection = getattr(self, f"{kind}_metrics")
-        metrics(y_hat, y)
+        metrics(*ys)

         for name, metric in metrics.items():
             self.log(
@@ -336,6 +341,9 @@ def log_metrics(
                 **logger_kwargs,
             )

+    def metrics_preprocess(self, y_hat, y) -> Tuple[torch.Tensor, torch.Tensor]:
+        return y_hat, y
+
     @L.LightningModule.trainer.setter
     def trainer(self, trainer):
         # Call the original setter
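The hunk above introduces a `metrics_preprocess` hook that subclasses can override to reshape predictions and targets before they reach the metric collection, while `log_metrics` defaults to a pass-through. A minimal, dependency-free sketch of that pattern (the classes below are illustrative stand-ins, not the actual deeplay `Application`):

```python
# Sketch of the metrics_preprocess hook pattern from the diff.
# These classes are hypothetical stand-ins, not the real deeplay API.

class Application:
    def log_metrics(self, y_hat, y):
        # Let subclasses preprocess before metrics are computed.
        ys = self.metrics_preprocess(y_hat, y)
        return self.compute_metrics(*ys)

    def metrics_preprocess(self, y_hat, y):
        # Default: identity, matching the base implementation in the diff.
        return y_hat, y

    def compute_metrics(self, y_hat, y):
        # Stand-in for the torchmetrics MetricCollection call: accuracy.
        correct = sum(int(p == t) for p, t in zip(y_hat, y))
        return correct / len(y)


class Classifier(Application):
    def metrics_preprocess(self, y_hat, y):
        # Example override: turn per-class scores into hard labels.
        labels = [max(range(len(row)), key=row.__getitem__) for row in y_hat]
        return labels, y


logits = [[0.1, 0.9], [0.8, 0.2]]
targets = [1, 0]
print(Classifier().log_metrics(logits, targets))  # prints 1.0
```

The base class never needs to know what preprocessing a given application requires; each subclass customizes only the hook.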
@@ -495,3 +503,48 @@ def log(self, name, value, **kwargs):
             kwargs.update({"batch_size": self._current_batch_size})

         super().log(name, value, **kwargs)
+
+    def build(self, *args, **kwargs):
+        if self.root_module is self:
+            try:
+                self._store_hparams(*args, **kwargs)
+            except PicklingError:
+                warn("Could not store hparams, checkpointing might not be available.")
+            self.__construct__()
+
+        return super().build(*args, **kwargs)
+
+    def _store_hparams(self, *args, **kwargs):
+        import pickle
+
+        for name, module in self.named_modules():
+            if not isinstance(module, DeeplayModule):
+                continue
+            self._user_config.remove_derived_configurations(module.tags)
+        self.__parent_hooks__ = {
+            "before_build": [],
+            "after_build": [],
+            "after_init": [],
+        }
+        self.__constructor_hooks__ = {
+            "before_build": [],
+            "after_build": [],
+            "after_init": [],
+        }
+        self._modules.clear()
+
+        _pickled_application = dill.dumps(self)
+        self._set_hparams(
+            {
+                "__from_ckpt_application": _pickled_application,
+                "__build_args": args,
+                "__build_kwargs": kwargs,
+            }
+        )
+
+        # restore the application
+        self.__construct__()
+
+    # @classmethod
+    # def load_from_checkpoint(cls, checkpoint_path: str | Path | np.IO, map_location: torch.device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[torch.device | str | int, torch.device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) -> Self:
+    #     return super().load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict, **kwargs)
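In the hunk above, `_store_hparams` strips derived configuration and hooks, pickles the application itself, and stashes the blob in the Lightning hparams together with the build arguments, so a later `load_from_checkpoint` can rebuild the exact application. A minimal sketch of that round trip, using stdlib `pickle` in place of `dill` and a hypothetical `TinyApp` in place of the real `Application`:

```python
import pickle


class TinyApp:
    """Hypothetical stand-in for a deeplay Application; not the real class."""

    def __init__(self, hidden=32):
        self.hidden = hidden
        self.hparams = {}

    def _store_hparams(self, *args, **kwargs):
        # Same idea as the diff: serialize the object itself and store the
        # blob next to the build arguments, so a checkpoint carries enough
        # information to reconstruct the application later.
        self.hparams = {
            "__from_ckpt_application": pickle.dumps(self),
            "__build_args": args,
            "__build_kwargs": kwargs,
        }

    @classmethod
    def restore(cls, hparams):
        # What a load_from_checkpoint-style hook would do: unpickle the
        # stored application, then re-apply the original build arguments.
        app = pickle.loads(hparams["__from_ckpt_application"])
        return app, hparams["__build_args"], hparams["__build_kwargs"]


app = TinyApp(hidden=64)
app._store_hparams(16, lr=1e-3)
restored, build_args, build_kwargs = TinyApp.restore(app.hparams)
print(restored.hidden, build_args, build_kwargs)  # 64 (16,) {'lr': 0.001}
```

The PR itself switches from `pickle` to `dill` (commit "Pickle using dill instead of pickle"), since dill can serialize object types that stdlib pickle rejects; the round-trip pattern is the same either way.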