Sped up MetricsCB and ProgressCB #18

Open · wants to merge 1 commit into base: master
Conversation

PiotrCzapla

As mentioned on the forum, with minimal changes to ProgressCB and MetricsCB we can speed up training significantly by allowing the following batch to be prepared while the current one is being processed on the GPU. The speedup is noticeable when data loading is fast, or when the model is slow enough to hide the data loading latency.
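
The underlying mechanism, as a minimal sketch with hypothetical names rather than the exact diff of this PR: calling loss.item() or copying tensors to the CPU on every batch forces a GPU synchronization, which stalls the DataLoader's prefetching; accumulating the statistic on-device and synchronizing once per epoch avoids the stall.

import torch

class DeferredMeanLoss:
    "Hypothetical illustration: accumulate on-device, sync with the CPU once per epoch."
    def __init__(self, device):
        self.total = torch.zeros((), device=device)
        self.count = 0
    def update(self, loss, n):
        # loss stays a device tensor; no .item() here, so no per-batch sync
        self.total += loss.detach() * n
        self.count += n
    def compute(self):
        # the single synchronization point, once per epoch
        return (self.total / self.count).item()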

@review-notebook-app
Check out this pull request on ReviewNB to see visual diffs and provide feedback on the Jupyter Notebooks.

  self.all_metrics = copy(metrics)
- self.all_metrics['loss'] = self.loss = Mean()
+ self.all_metrics['loss'] = self.loss = Mean(device='cpu' if 'mps' in device else device)
PiotrCzapla (Author)

MPS does not support doubles, and Mean fails when placed on the device because its weights are doubles.
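
For context, a quick standalone check (not part of the PR): float64 tensors cannot be created on the MPS backend at all, which is why the metric's double-precision state has to stay on the CPU.

import torch

if torch.backends.mps.is_available():
    torch.zeros(1, dtype=torch.float32, device='mps')      # fine
    try:
        torch.zeros(1, dtype=torch.float64, device='mps')  # raises
    except TypeError as e:
        print(e)  # "Cannot convert a MPS Tensor to float64 dtype ..."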

PiotrCzapla (Author) commented Feb 28, 2023

The caching code is simple, but I kept it out of this PR since it was not necessary; I could get away with

dls.train = list(dls.train)
dls.valid = list(dls.valid)

which is awesome, as it reinforces how flexible miniai is.
The code to cache a dataset in memory is more complex, and it is not necessary if your model is large enough, but it is a game changer on MPS, where multiprocessing works poorly. I'm not sure where to place it, though.

#| export
import fastcore.all as fc

def _with_features(ds):
    # Wrap the dataset in a fastcore list (materialising it in memory)
    # while keeping the `features` attribute that downstream code reads.
    setattr((l := fc.L(ds)), 'features', ds.features)
    return l

class CachedDS(dict):
    """Dict that does not print its contents, letting us inspect the dataset in Jupyter in reasonable time."""
    def __repr__(self): return "{ " + ", ".join(f'{k}: (#{len(v)})' for k, v in self.items()) + " }"
    def __str__(self): return repr(self)

def cache_dataset_as_dict(dd):
    # Eagerly cache every split of a DatasetDict-like mapping in memory.
    return CachedDS({dsn: _with_features(ds) for dsn, ds in dd.items()})
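
A usage sketch of my own (not part of the PR), assuming a Hugging Face DatasetDict; fashion_mnist is just an illustrative choice here:

from datasets import load_dataset

dd = load_dataset('fashion_mnist')   # a DatasetDict with train/test splits
cached = cache_dataset_as_dict(dd)   # every split is now a plain in-memory list
cached                               # repr prints e.g. { train: (#60000), test: (#10000) }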
