From f104cfe281cd7a848d9af76ef4c20af119da4536 Mon Sep 17 00:00:00 2001
From: Nan Xiao
Date: Fri, 25 Oct 2024 21:21:34 -0400
Subject: [PATCH 1/2] Use AdamW optimizer and CosineAnnealingWarmRestarts
 scheduler

---
 src/tinytopics/fit.py | 39 +++++++++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/src/tinytopics/fit.py b/src/tinytopics/fit.py
index f49bf3a..f198207 100644
--- a/src/tinytopics/fit.py
+++ b/src/tinytopics/fit.py
@@ -3,16 +3,32 @@
 from .models import NeuralPoissonNMF


-def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=None):
+def fit_model(
+    X,
+    k,
+    num_epochs=200,
+    batch_size=16,
+    base_lr=0.01,
+    max_lr=0.05,
+    T_0=20,
+    T_mult=1,
+    weight_decay=1e-5,
+    device=None,
+):
     """
-    Fit topic model via sum-to-one constrained neural Poisson NMF using batch gradient descent.
+    Fit topic model using sum-to-one constrained neural Poisson NMF,
+    optimized with AdamW and a cosine annealing with warm restarts scheduler.

     Args:
         X (torch.Tensor): Document-term matrix.
         k (int): Number of topics.
-        learning_rate (float, optional): Learning rate for Adam optimizer. Default is 0.001.
         num_epochs (int, optional): Number of training epochs. Default is 200.
-        batch_size (int, optional): Number of documents per batch. Default is 64.
+        batch_size (int, optional): Number of documents per batch. Default is 16.
+        base_lr (float, optional): Minimum learning rate after annealing. Default is 0.01.
+        max_lr (float, optional): Starting maximum learning rate. Default is 0.05.
+        T_0 (int, optional): Number of epochs until the first restart. Default is 20.
+        T_mult (int, optional): Factor by which the restart interval increases after each restart. Default is 1.
+        weight_decay (float, optional): Weight decay for the AdamW optimizer. Default is 1e-5.
         device (torch.device, optional): Device to run the training on. Defaults to CUDA if available, otherwise CPU.

     Returns:
@@ -24,19 +40,20 @@ def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=N
     n, m = X.shape

     model = NeuralPoissonNMF(n, m, k, device=device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, "min", patience=5, factor=0.5
+    optimizer = torch.optim.AdamW(
+        model.parameters(), lr=max_lr, weight_decay=weight_decay
+    )
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+        optimizer, T_0=T_0, T_mult=T_mult, eta_min=base_lr
     )
-    num_batches = n // batch_size

     losses = []

     with tqdm(total=num_epochs, desc="Training Progress") as pbar:
         for epoch in range(num_epochs):
             permutation = torch.randperm(n, device=device)
             epoch_loss = 0.0
+            num_batches = n // batch_size

             for i in range(num_batches):
                 indices = permutation[i * batch_size : (i + 1) * batch_size]
@@ -46,13 +63,15 @@ def fit_model(X, k, learning_rate=0.001, num_epochs=200, batch_size=64, device=N
                 X_reconstructed = model(indices)
                 loss = poisson_nmf_loss(batch_X, X_reconstructed)
                 loss.backward()
+
                 optimizer.step()

+                # Update per batch for cosine annealing with restarts
+                scheduler.step(epoch + i / num_batches)
                 epoch_loss += loss.item()

             epoch_loss /= num_batches
             losses.append(epoch_loss)

-            scheduler.step(epoch_loss)
             pbar.set_postfix({"Loss": f"{epoch_loss:.4f}"})
             pbar.update(1)

From bd503384c6d3c65253cf5f2e59e7697e9d5cc26a Mon Sep 17 00:00:00 2001
From: Nan Xiao
Date: Fri, 25 Oct 2024 21:24:56 -0400
Subject: [PATCH 2/2] Remove obsolete learning_rate parameter

---
 docs/articles/benchmark.md    | 3 +--
 docs/articles/benchmark.qmd   | 3 +--
 docs/articles/get-started.md  | 2 +-
 docs/articles/get-started.qmd | 2 +-
 examples/benchmark.py         | 3 +--
 examples/get-started.py       | 2 +-
 6 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/docs/articles/benchmark.md b/docs/articles/benchmark.md
index 14aa203..d01c7ba 100644
--- a/docs/articles/benchmark.md
+++ b/docs/articles/benchmark.md
@@ -61,7 +61,6 @@ Define parameter grids:
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256
 ```

@@ -72,7 +71,7 @@ benchmark_results = pd.DataFrame()

 def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time

     return elapsed_time
diff --git a/docs/articles/benchmark.qmd b/docs/articles/benchmark.qmd
index cff9e04..b271c95 100644
--- a/docs/articles/benchmark.qmd
+++ b/docs/articles/benchmark.qmd
@@ -62,7 +62,6 @@ Define parameter grids:
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256
 ```

@@ -73,7 +72,7 @@ benchmark_results = pd.DataFrame()

 def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time

     return elapsed_time
diff --git a/docs/articles/get-started.md b/docs/articles/get-started.md
index 48a67ab..57bb1d1 100644
--- a/docs/articles/get-started.md
+++ b/docs/articles/get-started.md
@@ -72,7 +72,7 @@
 Fit the topic model and plot the loss curve. There will be a progress bar.

 ``` python
-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)
 plot_loss(losses, output_file="loss.png")
 ```

diff --git a/docs/articles/get-started.qmd b/docs/articles/get-started.qmd
index e815dfa..0b4cf40 100644
--- a/docs/articles/get-started.qmd
+++ b/docs/articles/get-started.qmd
@@ -72,7 +72,7 @@ X, true_L, true_F = generate_synthetic_data(n, m, k, avg_doc_length=256 * 256)
 Fit the topic model and plot the loss curve. There will be a progress bar.

 ```{python}
-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)
 plot_loss(losses, output_file="loss.png")
 ```

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 56256db..2cd2b6e 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -12,7 +12,6 @@
 n_values = [1000, 5000]  # Number of documents
 m_values = [500, 1000, 5000, 10000]  # Vocabulary size
 k_values = [10, 50, 100]  # Number of topics
-learning_rate = 0.01
 avg_doc_length = 256 * 256


@@ -21,7 +20,7 @@ def benchmark(X, k, device):
     start_time = time.time()
-    model, losses = fit_model(X, k, learning_rate=learning_rate, device=device)
+    model, losses = fit_model(X, k, device=device)
     elapsed_time = time.time() - start_time

     return elapsed_time
diff --git a/examples/get-started.py b/examples/get-started.py
index bc1ec16..51b4cf0 100644
--- a/examples/get-started.py
+++ b/examples/get-started.py
@@ -15,7 +15,7 @@
 X, true_L, true_F = generate_synthetic_data(n, m, k, avg_doc_length=256 * 256)

-model, losses = fit_model(X, k, learning_rate=0.01)
+model, losses = fit_model(X, k)
 plot_loss(losses, output_file="loss.png")
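
Usage sketch (reviewer note, not part of the committed files): the call below exercises the updated `fit_model()` signature from PATCH 1/2 with its new defaults written out explicitly. The import path follows `src/tinytopics/fit.py`; the toy count matrix and its dimensions are illustrative assumptions, not taken from the patches.

``` python
import torch

from tinytopics.fit import fit_model

# Toy non-negative document-term matrix (Poisson counts);
# 500 documents x 100 terms is illustrative only.
X = torch.poisson(torch.rand(500, 100) * 5)

# Arguments mirror the new defaults; spelled out here for clarity.
model, losses = fit_model(
    X,
    k=10,
    num_epochs=200,
    batch_size=16,
    base_lr=0.01,       # passed to CosineAnnealingWarmRestarts as eta_min
    max_lr=0.05,        # initial learning rate handed to AdamW
    T_0=20,             # epochs until the first warm restart
    T_mult=1,           # restart interval stays constant
    weight_decay=1e-5,  # AdamW decoupled weight decay
)
```

The per-batch `scheduler.step(epoch + i / num_batches)` call in PATCH 1/2 passes a fractional epoch, so the learning rate moves along the cosine curve smoothly within each epoch rather than jumping once per epoch, and resets to `max_lr` every `T_0` epochs (with `T_mult=1`). A self-contained sketch of that pattern, with a dummy parameter and illustrative loop sizes:

``` python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.05, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, T_mult=1, eta_min=0.01
)

num_batches = 4
for epoch in range(25):
    for i in range(num_batches):
        optimizer.step()  # zero_grad()/backward() omitted in this sketch
        scheduler.step(epoch + i / num_batches)  # fractional epoch
    if epoch % 5 == 0 or epoch == 20:
        # lr decays from 0.05 toward 0.01, then restarts near 0.05 at epoch 20
        print(f"epoch {epoch:2d}: lr = {optimizer.param_groups[0]['lr']:.4f}")
```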