An implementation of OpenAI's paper 'Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets' in PyTorch.
- Clone the repo and cd into it:

  ```
  git clone https://github.com/danielmamay/grokking.git
  cd grokking
  ```
- Use Python 3.9 or later:

  ```
  conda create -n grokking python=3.9
  conda activate grokking
  pip install -r requirements.txt
  ```
The project uses Weights & Biases to keep track of experiments. Run `wandb login` to use the online dashboard, or `wandb offline` to store the data on your local machine.
- To run a single experiment using the CLI:

  ```
  wandb login
  python grokking/cli.py
  ```
- To run a grid search using W&B Sweeps:

  ```
  wandb sweep sweep.yaml
  wandb agent {entity}/grokking/{sweep_id}
  ```
Code:
Paper:
Figures:
Randomness:
- Originates from the random split of the data into training and validation sets
- Currently removed by adding `torch.manual_seed(0)` to data.py
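A minimal sketch of what a seeded train/validation split might look like (the function name and signature here are illustrative, not necessarily the repo's actual data.py API):

```python
import torch

def split_data(inputs, labels, train_frac=0.5, seed=0):
    """Deterministically split a dataset into train and validation parts.

    Seeding the generator before the shuffle removes the run-to-run
    randomness that would otherwise come from the random split.
    """
    torch.manual_seed(seed)  # fixed seed -> identical split every run
    n = len(inputs)
    perm = torch.randperm(n)  # shuffled indices, reproducible given the seed
    n_train = int(train_frac * n)
    train_idx, val_idx = perm[:n_train], perm[n_train:]
    return (
        (inputs[train_idx], labels[train_idx]),
        (inputs[val_idx], labels[val_idx]),
    )
```

Because the seed is fixed, calling the function twice on the same data yields the same split, which makes grokking curves comparable across runs.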
TODOs:
- Is the late fitting due to the bandwidth of frequencies needed for a sawtooth with period 97?
- How does fitting depend on the prime factor?