Merge pull request #3 from nanxstats/scheduler
Use AdamW optimizer and CosineAnnealingWarmRestarts scheduler
nanxstats authored Oct 26, 2024
2 parents e50ae90 + bd50338 commit 622b19a
Showing 7 changed files with 35 additions and 19 deletions.
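
For callers of `fit_model()`, the visible change is that the single fixed `learning_rate` argument is gone; the optimizer now runs AdamW under a cosine annealing with warm restarts schedule controlled by `base_lr`, `max_lr`, `T_0`, `T_mult`, and `weight_decay`, all with defaults. A minimal sketch of how call sites change, where `X` and `k` stand in for the document-term matrix and topic count from the examples below:

```python
# Before this commit: a fixed learning rate for Adam
# model, losses = fit_model(X, k, learning_rate=0.01)

# After this commit: the defaults cover the AdamW and scheduler settings
model, losses = fit_model(X, k)

# Or override the schedule explicitly
model, losses = fit_model(X, k, base_lr=0.01, max_lr=0.05, T_0=20, T_mult=1)
```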
docs/articles/benchmark.md (1 addition, 2 deletions)

@@ -61,7 +61,6 @@ Define parameter grids:
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256
 ```
@@ -72,7 +71,7 @@ benchmark_results = pd.DataFrame()

 def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time

     return elapsed_time
docs/articles/benchmark.qmd (1 addition, 2 deletions)

@@ -62,7 +62,6 @@ Define parameter grids:
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256
 ```

@@ -73,7 +72,7 @@ benchmark_results = pd.DataFrame()
 def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time
     return elapsed_time
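
Both benchmark documents now lean on these defaults instead of threading `learning_rate` through. For context, a hypothetical driver loop showing how the grids and `benchmark()` above could fit together; the `tinytopics` import path and the result columns are assumptions, not taken from this diff:

```python
import itertools

import pandas as pd
import torch

# Assumed import path; the diff only shows fit_model and generate_synthetic_data in use.
from tinytopics import fit_model, generate_synthetic_data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
benchmark_results = pd.DataFrame()

# n_values, m_values, k_values, avg_doc_length, and benchmark() as defined above.
for n, m, k in itertools.product(n_values, m_values, k_values):
    X, true_L, true_F = generate_synthetic_data(n, m, k, avg_doc_length=avg_doc_length)
    seconds = benchmark(X, k, device)
    row = pd.DataFrame([{"n": n, "m": m, "k": k, "seconds": seconds}])
    benchmark_results = pd.concat([benchmark_results, row], ignore_index=True)
```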
docs/articles/get-started.md (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ Fit the topic model and plot the loss curve. There will be a progress
 bar.

 ``` python
-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)

 plot_loss(losses, output_file="loss.png")
 ```
docs/articles/get-started.qmd (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ X, true_L, true_F = generate_synthetic_data(n, m, k, avg_doc_length=256 * 256)
 Fit the topic model and plot the loss curve. There will be a progress bar.

 ```{python}
-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)
 plot_loss(losses, output_file="loss.png")
 ```
examples/benchmark.py (1 addition, 2 deletions)

@@ -12,7 +12,6 @@
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256


@@ -21,7 +20,7 @@

 def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time

     return elapsed_time
examples/get-started.py (1 addition, 1 deletion)

@@ -15,7 +15,7 @@
 X, true_L, true_F = generate_synthetic_data(n, m, k, avg_doc_length=256 * 256)


-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)

 plot_loss(losses, output_file="loss.png")

src/tinytopics/fit.py (29 additions, 10 deletions)

@@ -3,16 +3,32 @@
 from .models import NeuralPoissonNMF


-def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=None):
+def fit_model(
+    X,
+    k,
+    num_epochs=200,
+    batch_size=16,
+    base_lr=0.01,
+    max_lr=0.05,
+    T_0=20,
+    T_mult=1,
+    weight_decay=1e-5,
+    device=None,
+):
     """
-    Fit topic model via sum-to-one constrained neural Poisson NMF using batch gradient descent.
+    Fit topic model using sum-to-one constrained neural Poisson NMF,
+    optimized with AdamW and a cosine annealing with warm restarts scheduler.

     Args:
         X (torch.Tensor): Document-term matrix.
         k (int): Number of topics.
-        learning_rate (float, optional): Learning rate for Adam optimizer. Default is 0.001.
         num_epochs (int, optional): Number of training epochs. Default is 200.
-        batch_size (int, optional): Number of documents per batch. Default is 64.
+        batch_size (int, optional): Number of documents per batch. Default is 16.
+        base_lr (float, optional): Minimum learning rate after annealing. Default is 0.01.
+        max_lr (float, optional): Starting maximum learning rate. Default is 0.05.
+        T_0 (int, optional): Number of epochs until the first restart. Default is 20.
+        T_mult (int, optional): Factor by which the restart interval increases after each restart. Default is 1.
+        weight_decay (float, optional): Weight decay for the AdamW optimizer. Default is 1e-5.
         device (torch.device, optional): Device to run the training on. Defaults to CUDA if available, otherwise CPU.

     Returns:
@@ -24,19 +40,20 @@ def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=None):
     n, m = X.shape

     model = NeuralPoissonNMF(n, m, k, device=device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, "min", patience=5, factor=0.5
+    optimizer = torch.optim.AdamW(
+        model.parameters(), lr=max_lr, weight_decay=weight_decay
     )
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+        optimizer, T_0=T_0, T_mult=T_mult, eta_min=base_lr
+    )

+    num_batches = n // batch_size
     losses = []

     with tqdm(total=num_epochs, desc="Training Progress") as pbar:
         for epoch in range(num_epochs):
             permutation = torch.randperm(n, device=device)
             epoch_loss = 0.0
-            num_batches = n // batch_size

             for i in range(num_batches):
                 indices = permutation[i * batch_size : (i + 1) * batch_size]
@@ -46,13 +63,15 @@ def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=None):
                 X_reconstructed = model(indices)
                 loss = poisson_nmf_loss(batch_X, X_reconstructed)
                 loss.backward()

                 optimizer.step()
+                # Update per batch for cosine annealing with restarts
+                scheduler.step(epoch + i / num_batches)
+
                 epoch_loss += loss.item()

             epoch_loss /= num_batches
             losses.append(epoch_loss)
-            scheduler.step(epoch_loss)
             pbar.set_postfix({"Loss": f"{epoch_loss:.4f}"})
             pbar.update(1)
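
To see what the new schedule does in isolation, here is a small sketch, not part of the library, that drives `CosineAnnealingWarmRestarts` with the same fractional-epoch stepping as the loop above. PyTorch computes the rate as `eta_min + (lr - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2`, so it decays from `max_lr` toward `base_lr` over each cycle and snaps back to `max_lr` at every restart:

```python
import torch

# Toy parameter so the optimizer has something to manage.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.05, weight_decay=1e-5)  # lr = max_lr
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, T_mult=1, eta_min=0.01  # eta_min = base_lr
)

num_epochs, num_batches = 40, 10
for epoch in range(num_epochs):
    for i in range(num_batches):
        optimizer.step()  # a real training step would compute gradients first
        # Fractional epoch, as in fit.py: anneal smoothly within the epoch.
        scheduler.step(epoch + i / num_batches)
    print(f"epoch {epoch:2d}  lr = {optimizer.param_groups[0]['lr']:.4f}")
```

With `T_mult=1` every cycle lasts `T_0` epochs, so with the defaults the rate restarts at epochs 20, 40, 60, and so on; `T_mult=2` would instead double each interval (20, then 40, then 80). The per-batch `scheduler.step()` call replaces the old per-epoch `ReduceLROnPlateau` step that keyed off the epoch loss.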
