From af6f3d94c50e752ea6febd9ee0a8a576b8a7255c Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 16:04:03 -0500 Subject: [PATCH 1/6] Run rye sync --update-all to upgrade all deps --- requirements-dev.lock | 39 +++++++++++++++++++-------------------- requirements.lock | 10 +++++----- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/requirements-dev.lock b/requirements-dev.lock index 428fa1a..3b577b4 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -33,7 +33,7 @@ babel==2.16.0 # via mkdocs-material beautifulsoup4==4.12.3 # via nbconvert -bleach==6.1.0 +bleach==6.2.0 # via nbconvert certifi==2024.8.30 # via httpcore @@ -56,7 +56,7 @@ contourpy==1.3.0 # via matplotlib cycler==0.12.1 # via matplotlib -debugpy==1.8.7 +debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython @@ -95,7 +95,7 @@ ipykernel==6.29.5 # via jupyter # via jupyter-console # via jupyterlab -ipython==8.28.0 +ipython==8.29.0 # via ipykernel # via ipywidgets # via jupyter-console @@ -103,7 +103,7 @@ ipywidgets==8.1.5 # via jupyter isoduration==20.11.0 # via jsonschema -jedi==0.19.1 +jedi==0.19.2 # via ipython jinja2==3.1.4 # via jupyter-server @@ -114,7 +114,7 @@ jinja2==3.1.4 # via mkdocstrings # via nbconvert # via torch -json5==0.9.25 +json5==0.9.28 # via jupyterlab-server jsonpointer==3.0.0 # via jsonschema @@ -198,10 +198,10 @@ mkdocs-autorefs==1.2.0 # via mkdocstrings-python mkdocs-get-deps==0.2.0 # via mkdocs -mkdocs-material==9.5.42 +mkdocs-material==9.5.44 mkdocs-material-extensions==1.3.1 # via mkdocs-material -mkdocstrings==0.26.2 +mkdocstrings==0.27.0 # via mkdocstrings-python mkdocstrings-python==1.12.2 mpmath==1.3.0 @@ -225,7 +225,7 @@ notebook==7.2.2 notebook-shim==0.2.4 # via jupyterlab # via notebook -numpy==2.1.2 +numpy==2.1.3 # via contourpy # via imageio # via matplotlib @@ -235,7 +235,7 @@ numpy==2.1.2 # via tinytopics overrides==7.7.0 # via jupyter-server -packaging==24.1 +packaging==24.2 # via ipykernel # via jupyter-server # via jupyterlab @@ -282,7 +282,7 @@ pygments==2.18.0 # via jupyter-console # via mkdocs-material # via nbconvert -pymdown-extensions==10.11.2 +pymdown-extensions==10.12 # via mkdocs-material # via mkdocstrings pyparsing==3.2.0 @@ -311,7 +311,7 @@ referencing==0.35.1 # via jsonschema # via jsonschema-specifications # via jupyter-events -regex==2024.9.11 +regex==2024.11.6 # via mkdocs-material requests==2.32.3 # via jupyterlab-server @@ -322,10 +322,10 @@ rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 # via jsonschema # via jupyter-events -rpds-py==0.20.0 +rpds-py==0.21.0 # via jsonschema # via referencing -ruff==0.7.0 +ruff==0.7.3 scikit-image==0.24.0 # via tinytopics scipy==1.14.1 @@ -333,12 +333,11 @@ scipy==1.14.1 # via tinytopics send2trash==1.8.3 # via jupyter-server -setuptools==75.2.0 +setuptools==75.4.0 # via jupyterlab # via torch six==1.16.0 # via asttokens - # via bleach # via python-dateutil # via rfc3339-validator sniffio==1.3.1 @@ -355,9 +354,9 @@ terminado==0.18.1 # via jupyter-server-terminals tifffile==2024.9.20 # via scikit-image -tinycss2==1.3.0 +tinycss2==1.4.0 # via nbconvert -torch==2.5.0 +torch==2.5.1 # via tinytopics tornado==6.4.1 # via ipykernel @@ -366,7 +365,7 @@ tornado==6.4.1 # via jupyterlab # via notebook # via terminado -tqdm==4.66.5 +tqdm==4.67.0 # via tinytopics traitlets==5.14.3 # via comm @@ -391,11 +390,11 @@ uri-template==1.3.0 # via jsonschema urllib3==2.2.3 # via requests -watchdog==5.0.3 +watchdog==6.0.0 # via mkdocs wcwidth==0.2.13 # via prompt-toolkit -webcolors==24.8.0 +webcolors==24.11.1 # via jsonschema webencodings==0.5.1 # via bleach diff --git a/requirements.lock b/requirements.lock index 0acb569..23633b1 100644 --- a/requirements.lock +++ b/requirements.lock @@ -37,7 +37,7 @@ mpmath==1.3.0 networkx==3.4.2 # via scikit-image # via torch -numpy==2.1.2 +numpy==2.1.3 # via contourpy # via imageio # via matplotlib @@ -45,7 +45,7 @@ numpy==2.1.2 # via scipy # via tifffile # via tinytopics -packaging==24.1 +packaging==24.2 # via lazy-loader # via matplotlib # via scikit-image @@ -62,7 +62,7 @@ scikit-image==0.24.0 scipy==1.14.1 # via scikit-image # via tinytopics -setuptools==75.2.0 +setuptools==75.4.0 # via torch six==1.16.0 # via python-dateutil @@ -70,9 +70,9 @@ sympy==1.13.1 # via torch tifffile==2024.9.20 # via scikit-image -torch==2.5.0 +torch==2.5.1 # via tinytopics -tqdm==4.66.5 +tqdm==4.67.0 # via tinytopics typing-extensions==4.12.2 # via torch From 47aa352b818541756f20753582cd6f77fb00e046 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 16:05:14 -0500 Subject: [PATCH 2/6] Use type hints and a more functional style --- src/tinytopics/colors.py | 79 +++++++++++------ src/tinytopics/fit.py | 98 +++++++++++---------- src/tinytopics/models.py | 41 +++++---- src/tinytopics/plot.py | 178 ++++++++++++++++++++------------------- src/tinytopics/utils.py | 107 ++++++++++++----------- 5 files changed, 276 insertions(+), 227 deletions(-) diff --git a/src/tinytopics/colors.py b/src/tinytopics/colors.py index 8ad75e1..6c7f990 100644 --- a/src/tinytopics/colors.py +++ b/src/tinytopics/colors.py @@ -1,10 +1,27 @@ +from typing import List, Union, Literal, overload + import numpy as np +from numpy.typing import NDArray from matplotlib import colors +from matplotlib.colors import ListedColormap from skimage import color from scipy.interpolate import make_interp_spline -def pal_tinytopics(format="hex"): +ColorFormat = Literal["hex", "rgb", "lab"] + + +@overload +def pal_tinytopics(format: Literal["hex"]) -> List[str]: ... + + +@overload +def pal_tinytopics(format: Literal["rgb", "lab"]) -> NDArray[np.float64]: ... + + +def pal_tinytopics( + format: ColorFormat = "hex", +) -> Union[List[str], NDArray[np.float64]]: """ The tinytopics 10 color palette. @@ -15,19 +32,20 @@ def pal_tinytopics(format="hex"): especially when used in a context where color interpolation is needed. Args: - format (str, optional): - Returned color format. Options are: + format: Returned color format. Options are: `hex`: Hex strings (default). `rgb`: Array of RGB values. `lab`: Array of CIELAB values. Returns: - (list or np.ndarray): - - If `format='hex'`, returns a list of hex color strings. - - If `format='rgb'`, returns an Nx3 numpy array of RGB values. - - If `format='lab'`, returns an Nx3 numpy array of CIELAB values. + - If `format='hex'`, returns a list of hex color strings. + - If `format='rgb'`, returns an Nx3 numpy array of RGB values. + - If `format='lab'`, returns an Nx3 numpy array of CIELAB values. + + Raises: + ValueError: If format is not 'hex', 'rgb', or 'lab'. """ - tinytopics_10_colors_hex = [ + TINYTOPICS_10_COLORS: tuple[str, ...] = ( "#4269D0", # Blue "#EFB118", # Orange "#3CA951", # Green @@ -38,37 +56,39 @@ def pal_tinytopics(format="hex"): "#9498A0", # Gray "#6CC5B0", # Cyan "#97BBF5", # Light Blue - ] + ) if format == "hex": - return tinytopics_10_colors_hex - elif format == "rgb": - # Convert hex to RGB - return np.array([colors.to_rgb(color) for color in tinytopics_10_colors_hex]) + return list(TINYTOPICS_10_COLORS) + + # Convert hex to RGB array + rgb_colors: NDArray[np.float64] = np.array( + [colors.to_rgb(color) for color in TINYTOPICS_10_COLORS] + ) + + if format == "rgb": + return rgb_colors elif format == "lab": - # Convert hex to RGB, then to CIELAB - rgb_colors = np.array( - [colors.to_rgb(color) for color in tinytopics_10_colors_hex] - ) + # Convert RGB to CIELAB return color.rgb2lab(rgb_colors.reshape(1, -1, 3)).reshape(-1, 3) else: raise ValueError("Format must be 'hex', 'rgb', or 'lab'.") -def scale_color_tinytopics(n): +def scale_color_tinytopics(n: int) -> ListedColormap: """ A tinytopics 10 color scale. If > 10 colors are required, will generate an interpolated color palette based on the 10-color palette in the CIELAB color space using B-splines. Args: - n (int): The number of colors needed. + n: The number of colors needed. Returns: - (matplotlib.colors.ListedColormap): A colormap with n colors, possibly interpolated from the 10 colors. + A colormap with n colors, possibly interpolated from the 10 colors. """ - base_rgb_colors = pal_tinytopics(format="rgb") - base_lab_colors = pal_tinytopics(format="lab") + base_rgb_colors: NDArray[np.float64] = pal_tinytopics(format="rgb") + base_lab_colors: NDArray[np.float64] = pal_tinytopics(format="lab") # If interpolation is NOT needed, return the first n colors directly if n <= len(base_rgb_colors): @@ -76,16 +96,21 @@ def scale_color_tinytopics(n): # If interpolation is needed, interpolate in the CIELAB space # for perceptually uniform colors - additional_colors_needed = n - 10 + additional_colors_needed: int = n - 10 + # Original positions of the 10 base colors - x = np.linspace(0, 1, len(base_lab_colors)) + x: NDArray[np.float64] = np.linspace(0, 1, len(base_lab_colors)) + # B-spline interpolator in the CIELAB space bspline = make_interp_spline(x, base_lab_colors, k=3) + # Interpolated positions for new colors - x_new = np.linspace(0, 1, additional_colors_needed + 10) - interpolated_lab = bspline(x_new) + x_new: NDArray[np.float64] = np.linspace(0, 1, additional_colors_needed + 10) + interpolated_lab: NDArray[np.float64] = bspline(x_new) # Convert interpolated LAB colors back to RGB - interpolated_rgb = color.lab2rgb(interpolated_lab.reshape(1, -1, 3)).reshape(-1, 3) + interpolated_rgb: NDArray[np.float64] = color.lab2rgb( + interpolated_lab.reshape(1, -1, 3) + ).reshape(-1, 3) return colors.ListedColormap(interpolated_rgb) diff --git a/src/tinytopics/fit.py b/src/tinytopics/fit.py index f198207..1affeab 100644 --- a/src/tinytopics/fit.py +++ b/src/tinytopics/fit.py @@ -1,59 +1,81 @@ +from typing import List, Optional, Tuple + import torch +from torch import Tensor +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts from tqdm import tqdm + from .models import NeuralPoissonNMF +def poisson_nmf_loss(X: Tensor, X_reconstructed: Tensor) -> Tensor: + """ + Compute the Poisson NMF loss function (negative log-likelihood). + + Args: + X: Original document-term matrix. + X_reconstructed: Reconstructed matrix from the model. + + Returns: + The computed Poisson NMF loss. + """ + epsilon: float = 1e-10 + return ( + X_reconstructed - X * torch.log(torch.clamp(X_reconstructed, min=epsilon)) + ).sum() + + def fit_model( - X, - k, - num_epochs=200, - batch_size=16, - base_lr=0.01, - max_lr=0.05, - T_0=20, - T_mult=1, - weight_decay=1e-5, - device=None, -): + X: Tensor, + k: int, + num_epochs: int = 200, + batch_size: int = 16, + base_lr: float = 0.01, + max_lr: float = 0.05, + T_0: int = 20, + T_mult: int = 1, + weight_decay: float = 1e-5, + device: Optional[torch.device] = None, +) -> Tuple[NeuralPoissonNMF, List[float]]: """ Fit topic model using sum-to-one constrained neural Poisson NMF, optimized with AdamW and a cosine annealing with warm restarts scheduler. Args: - X (torch.Tensor): Document-term matrix. - k (int): Number of topics. - num_epochs (int, optional): Number of training epochs. Default is 200. - batch_size (int, optional): Number of documents per batch. Default is 16. - base_lr (float, optional): Minimum learning rate after annealing. Default is 0.01. - max_lr (float, optional): Starting maximum learning rate. Default is 0.05. - T_0 (int, optional): Number of epochs until the first restart. Default is 20. - T_mult (int, optional): Factor by which the restart interval increases after each restart. Default is 1. - weight_decay (float, optional): Weight decay for the AdamW optimizer. Default is 1e-5. - device (torch.device, optional): Device to run the training on. Defaults to CUDA if available, otherwise CPU. + X: Document-term matrix. + k: Number of topics. + num_epochs: Number of training epochs. Default is 200. + batch_size: Number of documents per batch. Default is 16. + base_lr: Minimum learning rate after annealing. Default is 0.01. + max_lr: Starting maximum learning rate. Default is 0.05. + T_0: Number of epochs until the first restart. Default is 20. + T_mult: Factor by which the restart interval increases after each restart. Default is 1. + weight_decay: Weight decay for the AdamW optimizer. Default is 1e-5. + device: Device to run the training on. Defaults to CUDA if available, otherwise CPU. Returns: - (NeuralPoissonNMF): Trained model. - (list): List of training losses for each epoch. + A tuple containing: + - The trained NeuralPoissonNMF model + - List of training losses for each epoch """ device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") X = X.to(device) n, m = X.shape - model = NeuralPoissonNMF(n, m, k, device=device) - optimizer = torch.optim.AdamW( - model.parameters(), lr=max_lr, weight_decay=weight_decay - ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + model = NeuralPoissonNMF(n=n, m=m, k=k, device=device) + optimizer = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay) + scheduler = CosineAnnealingWarmRestarts( optimizer, T_0=T_0, T_mult=T_mult, eta_min=base_lr ) - losses = [] + losses: List[float] = [] + num_batches: int = n // batch_size with tqdm(total=num_epochs, desc="Training Progress") as pbar: for epoch in range(num_epochs): permutation = torch.randperm(n, device=device) - epoch_loss = 0.0 - num_batches = n // batch_size + epoch_loss: float = 0.0 for i in range(num_batches): indices = permutation[i * batch_size : (i + 1) * batch_size] @@ -76,17 +98,3 @@ def fit_model( pbar.update(1) return model, losses - - -def poisson_nmf_loss(X, X_reconstructed): - """ - Compute the Poisson NMF loss function (negative log-likelihood). - - Args: - X (torch.Tensor): Original document-term matrix. - X_reconstructed (torch.Tensor): Reconstructed matrix from the model. - """ - epsilon = 1e-10 - return ( - X_reconstructed - X * torch.log(torch.clamp(X_reconstructed, min=epsilon)) - ).sum() diff --git a/src/tinytopics/models.py b/src/tinytopics/models.py index 3346dfa..153308b 100644 --- a/src/tinytopics/models.py +++ b/src/tinytopics/models.py @@ -1,69 +1,74 @@ +from typing import Optional + import torch import torch.nn as nn +from torch import Tensor class NeuralPoissonNMF(nn.Module): - def __init__(self, n, m, k, device=None): + def __init__( + self, n: int, m: int, k: int, device: Optional[torch.device] = None + ) -> None: """ Neural Poisson NMF model with sum-to-one constraints on document-topic and topic-term distributions. Args: - n (int): Number of documents. - m (int): Number of terms (vocabulary size). - k (int): Number of topics. - device (torch.device, optional): Device to run the model on. Defaults to CPU. + n: Number of documents. + m: Number of terms (vocabulary size). + k: Number of topics. + device: Device to run the model on. Defaults to CPU. """ super(NeuralPoissonNMF, self).__init__() - self.device = device or torch.device("cpu") + self.device: torch.device = device or torch.device("cpu") # Use embedding for L to handle batches efficiently - self.L = nn.Embedding(n, k).to(self.device) + self.L: nn.Embedding = nn.Embedding(n, k).to(self.device) # Initialize L with small positive values nn.init.uniform_(self.L.weight, a=0.0, b=0.1) # Define F as a parameter and initialize with small positive values - self.F = nn.Parameter(torch.empty(k, m, device=self.device)) + self.F: nn.Parameter = nn.Parameter(torch.empty(k, m, device=self.device)) nn.init.uniform_(self.F, a=0.0, b=0.1) - def forward(self, doc_indices): + def forward(self, doc_indices: Tensor) -> Tensor: """ Forward pass of the neural Poisson NMF model. Args: - doc_indices (torch.Tensor): Indices of documents in the batch. + doc_indices: Indices of documents in the batch. Returns: - (torch.Tensor): Reconstructed document-term matrix for the batch. + Reconstructed document-term matrix for the batch. """ # Get the L vectors for the batch - L_batch = self.L(doc_indices) + L_batch: Tensor = self.L(doc_indices) # Sum-to-one constraints across topics for each document - L_normalized = torch.softmax(L_batch, dim=1) + L_normalized: Tensor = torch.softmax(L_batch, dim=1) # Sum-to-one constraints across terms for each topic - F_normalized = torch.softmax(self.F, dim=1) + F_normalized: Tensor = torch.softmax(self.F, dim=1) # Return the matrix product to approximate X_batch return torch.matmul(L_normalized, F_normalized) - def get_normalized_L(self): + def get_normalized_L(self) -> Tensor: """ Get the learned, normalized document-topic distribution matrix (L). Returns: - (torch.Tensor): Normalized L matrix on the CPU. + Normalized L matrix on the CPU. """ with torch.no_grad(): return torch.softmax(self.L.weight, dim=1).cpu() - def get_normalized_F(self): + def get_normalized_F(self) -> Tensor: """ Get the learned, normalized topic-term distribution matrix (F). Returns: - (torch.Tensor): Normalized F matrix on the CPU. + Normalized F matrix on the CPU. """ with torch.no_grad(): return torch.softmax(self.F, dim=1).cpu() diff --git a/src/tinytopics/plot.py b/src/tinytopics/plot.py index c8a9282..3b8baea 100644 --- a/src/tinytopics/plot.py +++ b/src/tinytopics/plot.py @@ -1,35 +1,40 @@ +from typing import List, Optional, Tuple, Union + import numpy as np import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.axes import Axes + from .colors import scale_color_tinytopics def plot_loss( - losses, - figsize=(10, 8), - dpi=300, - title="Loss curve", - color_palette=None, - output_file=None, -): + losses: List[float], + figsize: Tuple[int, int] = (10, 8), + dpi: int = 300, + title: str = "Loss curve", + color_palette: Optional[Union[List[str], str]] = None, + output_file: Optional[str] = None, +) -> None: """ Plot the loss curve over training epochs. Args: - losses (list): List of loss values for each epoch. - figsize (tuple, optional): Plot size. Default is (10, 8). - dpi (int, optional): Plot resolution. Default is 300. - title (str, optional): Plot title. Default is "Loss curve". - color_palette (list or matplotlib colormap, optional): Custom color palette. - output_file (str, optional): File path to save the plot. If None, displays the plot. + losses: List of loss values for each epoch. + figsize: Plot size. Default is (10, 8). + dpi: Plot resolution. Default is 300. + title: Plot title. Default is "Loss curve". + color_palette: Custom color palette. + output_file: File path to save the plot. If None, displays the plot. """ - if color_palette is None: - color_palette = scale_color_tinytopics(1) + palette = scale_color_tinytopics(1) if color_palette is None else color_palette - plt.figure(figsize=figsize, dpi=dpi) - plt.plot(losses, color=color_palette(0)) + fig = plt.figure(figsize=figsize, dpi=dpi) + plt.plot(losses, color=palette(0)) plt.title(title) plt.xlabel("Epochs") plt.ylabel("Loss") + if output_file: plt.savefig(output_file, dpi=dpi) plt.close() @@ -38,46 +43,47 @@ def plot_loss( def plot_structure( - L_matrix, - normalize_rows=False, - figsize=(12, 6), - dpi=300, - title="Structure Plot", - color_palette=None, - output_file=None, -): + L_matrix: np.ndarray, + normalize_rows: bool = False, + figsize: Tuple[int, int] = (12, 6), + dpi: int = 300, + title: str = "Structure Plot", + color_palette: Optional[Union[List[str], str]] = None, + output_file: Optional[str] = None, +) -> None: """ Structure plot for visualizing document-topic distributions. Args: - L_matrix (np.ndarray): Document-topic distribution matrix. - normalize_rows (bool, optional): If True, normalizes each row of L_matrix to sum to 1. - figsize (tuple, optional): Plot size. Default is (12, 6). - dpi (int, optional): Plot resolution. Default is 300. - title (str): Plot title. - color_palette (list or matplotlib colormap, optional): Custom color palette. - output_file (str, optional): File path to save the plot. If None, displays the plot. + L_matrix: Document-topic distribution matrix. + normalize_rows: If True, normalizes each row of L_matrix to sum to 1. + figsize: Plot size. Default is (12, 6). + dpi: Plot resolution. Default is 300. + title: Plot title. + color_palette: Custom color palette. + output_file: File path to save the plot. If None, displays the plot. """ - if normalize_rows: - L_matrix = L_matrix / L_matrix.sum(axis=1, keepdims=True) - - n_documents, n_topics = L_matrix.shape - ind = np.arange(n_documents) # Document indices + matrix = ( + L_matrix / L_matrix.sum(axis=1, keepdims=True) if normalize_rows else L_matrix + ) + n_documents, n_topics = matrix.shape + ind = np.arange(n_documents) cumulative = np.zeros(n_documents) - - if color_palette is None: - color_palette = scale_color_tinytopics(n_topics) + palette = ( + scale_color_tinytopics(n_topics) if color_palette is None else color_palette + ) plt.figure(figsize=figsize, dpi=dpi) for k in range(n_topics): plt.bar( ind, - L_matrix[:, k], + matrix[:, k], bottom=cumulative, - color=color_palette(k), + color=palette(k), width=1.0, ) - cumulative += L_matrix[:, k] + cumulative += matrix[:, k] + plt.title(title) plt.xlabel("Documents (sorted)") plt.ylabel("Topic Proportions") @@ -93,46 +99,45 @@ def plot_structure( def plot_top_terms( - F_matrix, - n_top_terms=10, - term_names=None, - figsize=(10, 8), - dpi=300, - title="Top Terms", - color_palette=None, - nrows=None, - ncols=None, - output_file=None, -): + F_matrix: np.ndarray, + n_top_terms: int = 10, + term_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (10, 8), + dpi: int = 300, + title: str = "Top Terms", + color_palette: Optional[Union[List[str], str]] = None, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + output_file: Optional[str] = None, +) -> None: """ Plot top terms for each topic in horizontal bar charts. Args: - F_matrix (np.ndarray): Topic-term distribution matrix. - n_top_terms (int, optional): Number of top terms to display per topic. Default is 10. - term_names (list, optional): List of term names corresponding to indices. - figsize (tuple, optional): Plot size. Default is (10, 8). - dpi (int, optional): Plot resolution. Default is 300. - title (str): Plot title. - color_palette (list or matplotlib colormap, optional): Custom color palette. - nrows (int, optional): Number of rows in the subplot grid. - ncols (int, optional): Number of columns in the subplot grid. - output_file (str, optional): File path to save the plot. If None, displays the plot. + F_matrix: Topic-term distribution matrix. + n_top_terms: Number of top terms to display per topic. Default is 10. + term_names: List of term names corresponding to indices. + figsize: Plot size. Default is (10, 8). + dpi: Plot resolution. Default is 300. + title: Plot title. + color_palette: Custom color palette. + nrows: Number of rows in the subplot grid. + ncols: Number of columns in the subplot grid. + output_file: File path to save the plot. If None, displays the plot. """ n_topics = F_matrix.shape[0] top_terms_indices = np.argsort(-F_matrix, axis=1)[:, :n_top_terms] top_terms_probs = np.take_along_axis(F_matrix, top_terms_indices, axis=1) + top_terms_labels = ( + np.array(term_names)[top_terms_indices] + if term_names is not None + else top_terms_indices.astype(str) + ) + palette = ( + scale_color_tinytopics(n_topics) if color_palette is None else color_palette + ) - # Use term names if provided - if term_names is not None: - top_terms_labels = np.array(term_names)[top_terms_indices] - else: - top_terms_labels = top_terms_indices.astype(str) - - if color_palette is None: - color_palette = scale_color_tinytopics(n_topics) - - # Grid layout + # Calculate grid dimensions if nrows is None and ncols is None: ncols = 5 nrows = int(np.ceil(n_topics / ncols)) @@ -144,25 +149,26 @@ def plot_top_terms( fig, axes = plt.subplots( nrows, ncols, figsize=figsize, dpi=dpi, constrained_layout=True ) - axes = axes.flatten() - - for i in range(n_topics): - ax = axes[i] - # Get data for topic i - probs = top_terms_probs[i] - labels = top_terms_labels[i] + axes_flat = axes.flatten() - # Place highest probability terms at the top + def plot_topic(ax: Axes, topic_idx: int) -> None: + probs = top_terms_probs[topic_idx] + labels = top_terms_labels[topic_idx] y_pos = np.arange(n_top_terms)[::-1] - ax.barh(y_pos, probs, color=color_palette(i)) + + ax.barh(y_pos, probs, color=palette(topic_idx)) ax.set_yticks(y_pos) ax.set_yticklabels(labels) ax.set_xlabel("Probability") - ax.set_title(f"Topic {i}") + ax.set_title(f"Topic {topic_idx}") ax.set_xlim(0, top_terms_probs.max() * 1.1) + + for i in range(n_topics): + plot_topic(axes_flat[i], i) + # Hide unused subplots - for j in range(n_topics, len(axes)): - fig.delaxes(axes[j]) + for j in range(n_topics, len(axes_flat)): + fig.delaxes(axes_flat[j]) fig.suptitle(title) diff --git a/src/tinytopics/utils.py b/src/tinytopics/utils.py index 6f94593..6a0d2ae 100644 --- a/src/tinytopics/utils.py +++ b/src/tinytopics/utils.py @@ -1,10 +1,13 @@ +from typing import Optional, Tuple, List +from collections import defaultdict + import torch import numpy as np -from tqdm import tqdm from scipy.optimize import linear_sum_assignment +from tqdm import tqdm -def set_random_seed(seed): +def set_random_seed(seed: int) -> None: """ Set the random seed for reproducibility across Torch and NumPy. @@ -17,7 +20,13 @@ def set_random_seed(seed): torch.cuda.manual_seed_all(seed) -def generate_synthetic_data(n, m, k, avg_doc_length=1000, device=None): +def generate_synthetic_data( + n: int, + m: int, + k: int, + avg_doc_length: int = 1000, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, np.ndarray, np.ndarray]: """ Generate synthetic document-term matrix for testing the model. @@ -47,30 +56,28 @@ def generate_synthetic_data(n, m, k, avg_doc_length=1000, device=None): # Initialize document-term matrix X X = np.zeros((n, m), dtype=np.int32) - for i in tqdm(range(n), desc="Generating Documents"): - # Sample topic counts for document i + def generate_document(i: int, doc_length: int) -> np.ndarray: topic_probs = true_L[i] - topic_counts = np.random.multinomial(doc_lengths[i], topic_probs) + topic_counts = np.random.multinomial(doc_length, topic_probs) - # Initialize term counts for document i - term_counts = np.zeros(m, dtype=np.int32) + def sample_terms_for_topic(j: int, count: int) -> np.ndarray: + if count == 0: + return np.zeros(m, dtype=np.int32) + term_probs = true_F[j] + return np.random.multinomial(count, term_probs) - # For each topic j - for j in range(k): - if topic_counts[j] > 0: - # Sample term counts for topic j - term_probs = true_F[j] - term_counts_j = np.random.multinomial(topic_counts[j], term_probs) - # Add term counts to document i - term_counts += term_counts_j + term_counts = sum( + sample_terms_for_topic(j, count) for j, count in enumerate(topic_counts) + ) + return term_counts - # Assign term counts to X[i,:] - X[i, :] = term_counts + for i in tqdm(range(n), desc="Generating Documents"): + X[i, :] = generate_document(i, doc_lengths[i]) return torch.tensor(X, device=device, dtype=torch.float32), true_L, true_F -def align_topics(true_F, learned_F): +def align_topics(true_F: np.ndarray, learned_F: np.ndarray) -> np.ndarray: """ Align learned topics with true topics for visualization, using cosine similarity and linear sum assignment. @@ -82,21 +89,21 @@ def align_topics(true_F, learned_F): Returns: (np.ndarray): Permutation of learned topics aligned with true topics. """ - # Normalize topic-term distributions - true_F_norm = true_F / np.linalg.norm(true_F, axis=1, keepdims=True) - learned_F_norm = learned_F / np.linalg.norm(learned_F, axis=1, keepdims=True) - # Compute the cosine similarity matrix + def normalize_matrix(matrix: np.ndarray) -> np.ndarray: + return matrix / np.linalg.norm(matrix, axis=1, keepdims=True) + + true_F_norm = normalize_matrix(true_F) + learned_F_norm = normalize_matrix(learned_F) + similarity_matrix = np.dot(true_F_norm, learned_F_norm.T) - # Compute the cost matrix for assignment (use negative similarity) cost_matrix = -similarity_matrix - # Solve the assignment problem - row_ind, col_ind = linear_sum_assignment(cost_matrix) + _, col_ind = linear_sum_assignment(cost_matrix) return col_ind -def sort_documents(L_matrix): +def sort_documents(L_matrix: np.ndarray) -> List[int]: """ Sort documents grouped by dominant topics for visualization. @@ -107,29 +114,27 @@ def sort_documents(L_matrix): (list): Indices of documents sorted by dominant topics. """ n, k = L_matrix.shape - # Normalize L L_normalized = L_matrix / L_matrix.sum(axis=1, keepdims=True) - # Determine dominant topics and their proportions - dominant_topics = np.argmax(L_normalized, axis=1) - dominant_props = L_normalized[np.arange(n), dominant_topics] - - # Combine indices, dominant topics, and proportions - doc_info = list(zip(np.arange(n), dominant_topics, dominant_props)) - - # Group documents by dominant topic - from collections import defaultdict - - grouped_docs = defaultdict(list) - for idx, topic, prop in doc_info: - grouped_docs[topic].append((idx, prop)) - - # Sort documents within each group by proportion of the dominant topic - sorted_indices = [] - for topic in range(k): - docs_in_topic = grouped_docs.get(topic, []) - # Sort by proportion in descending order - docs_sorted = sorted(docs_in_topic, key=lambda x: x[1], reverse=True) - sorted_indices.extend([idx for idx, _ in docs_sorted]) - - return sorted_indices + def get_document_info() -> List[Tuple[int, int, float]]: + dominant_topics = np.argmax(L_normalized, axis=1) + dominant_props = L_normalized[np.arange(n), dominant_topics] + return list(zip(range(n), dominant_topics, dominant_props)) + + def group_by_topic(doc_info: List[Tuple[int, int, float]]) -> defaultdict: + groups: defaultdict = defaultdict(list) + for idx, topic, prop in doc_info: + groups[topic].append((idx, prop)) + return groups + + def sort_topic_groups(grouped_docs: defaultdict) -> List[int]: + sorted_indices = [] + for topic in range(k): + docs_in_topic = grouped_docs.get(topic, []) + docs_sorted = sorted(docs_in_topic, key=lambda x: x[1], reverse=True) + sorted_indices.extend(idx for idx, _ in docs_sorted) + return sorted_indices + + doc_info = get_document_info() + grouped_docs = group_by_topic(doc_info) + return sort_topic_groups(grouped_docs) From d5b74c33cd32ffb31e901f412a513be06897b7f0 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 16:31:12 -0500 Subject: [PATCH 3/6] Use groups in import section --- docs/articles/benchmark.md | 1 + docs/articles/benchmark.qmd | 1 + docs/articles/text.md | 3 ++- docs/articles/text.qmd | 3 ++- examples/benchmark.py | 1 + examples/text.py | 3 ++- 6 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/articles/benchmark.md b/docs/articles/benchmark.md index c5452df..079c8c0 100644 --- a/docs/articles/benchmark.md +++ b/docs/articles/benchmark.md @@ -42,6 +42,7 @@ import time import torch import pandas as pd import matplotlib.pyplot as plt + from tinytopics.fit import fit_model from tinytopics.utils import generate_synthetic_data, set_random_seed from tinytopics.colors import scale_color_tinytopics diff --git a/docs/articles/benchmark.qmd b/docs/articles/benchmark.qmd index 9a1faaf..826606d 100644 --- a/docs/articles/benchmark.qmd +++ b/docs/articles/benchmark.qmd @@ -44,6 +44,7 @@ import time import torch import pandas as pd import matplotlib.pyplot as plt + from tinytopics.fit import fit_model from tinytopics.utils import generate_synthetic_data, set_random_seed from tinytopics.colors import scale_color_tinytopics diff --git a/docs/articles/text.md b/docs/articles/text.md index abdf26e..7e90cd8 100644 --- a/docs/articles/text.md +++ b/docs/articles/text.md @@ -22,10 +22,11 @@ repo](https://github.com/stephenslab/fastTopics-experiments). ## Import tinytopics ``` python +import torch import numpy as np import pandas as pd -import torch from pyreadr import read_r + from tinytopics.fit import fit_model from tinytopics.plot import plot_loss, plot_structure, plot_top_terms from tinytopics.utils import ( diff --git a/docs/articles/text.qmd b/docs/articles/text.qmd index a47159c..e2b537b 100644 --- a/docs/articles/text.qmd +++ b/docs/articles/text.qmd @@ -25,10 +25,11 @@ The NIPS dataset contains a count matrix for 2483 research papers on ## Import tinytopics ```{python} +import torch import numpy as np import pandas as pd -import torch from pyreadr import read_r + from tinytopics.fit import fit_model from tinytopics.plot import plot_loss, plot_structure, plot_top_terms from tinytopics.utils import ( diff --git a/examples/benchmark.py b/examples/benchmark.py index 6210816..04ff5c1 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -2,6 +2,7 @@ import torch import pandas as pd import matplotlib.pyplot as plt + from tinytopics.fit import fit_model from tinytopics.utils import generate_synthetic_data, set_random_seed from tinytopics.colors import scale_color_tinytopics diff --git a/examples/text.py b/examples/text.py index d7b04ce..c012277 100644 --- a/examples/text.py +++ b/examples/text.py @@ -1,7 +1,8 @@ +import torch import numpy as np import pandas as pd -import torch from pyreadr import read_r + from tinytopics.fit import fit_model from tinytopics.plot import plot_loss, plot_structure, plot_top_terms from tinytopics.utils import ( From fb9363c2e3e45514b52a65f72938ff0501ae7b40 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 16:55:19 -0500 Subject: [PATCH 4/6] Add group to import section --- docs/articles/benchmark.md | 1 + docs/articles/benchmark.qmd | 1 + examples/benchmark.py | 1 + 3 files changed, 3 insertions(+) diff --git a/docs/articles/benchmark.md b/docs/articles/benchmark.md index 079c8c0..f96fa44 100644 --- a/docs/articles/benchmark.md +++ b/docs/articles/benchmark.md @@ -39,6 +39,7 @@ Experiment environment: ``` python import time + import torch import pandas as pd import matplotlib.pyplot as plt diff --git a/docs/articles/benchmark.qmd b/docs/articles/benchmark.qmd index 826606d..800bbb0 100644 --- a/docs/articles/benchmark.qmd +++ b/docs/articles/benchmark.qmd @@ -41,6 +41,7 @@ Experiment environment: ```{python} import time + import torch import pandas as pd import matplotlib.pyplot as plt diff --git a/examples/benchmark.py b/examples/benchmark.py index 04ff5c1..f9ad872 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,4 +1,5 @@ import time + import torch import pandas as pd import matplotlib.pyplot as plt From 27fcc3fea56d9aeb05fcdcf1946af24f14a3e4f1 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 17:47:01 -0500 Subject: [PATCH 5/6] Increment version number to 0.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 77c2d3c..2bff0af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinytopics" -version = "0.2.0" +version = "0.3.0" description = "Topic modeling via sum-to-one constrained neural Poisson non-negative matrix factorization" authors = [ { name = "Nan Xiao", email = "me@nanx.me" } From a883e6d551fa8e62e04f8df33f7fb85a718f8901 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Mon, 11 Nov 2024 17:53:55 -0500 Subject: [PATCH 6/6] Update news for v0.3.0 --- CHANGELOG.md | 7 +++++++ docs/changelog.md | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4736ec..a208f07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## tinytopics 0.3.0 + +### Improvements + +- Refactor the code to use a more functional style and add type hints + to improve code clarity (#9). + ## tinytopics 0.2.0 ### New features diff --git a/docs/changelog.md b/docs/changelog.md index c4736ec..a208f07 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,12 @@ # Changelog +## tinytopics 0.3.0 + +### Improvements + +- Refactor the code to use a more functional style and add type hints + to improve code clarity (#9). + ## tinytopics 0.2.0 ### New features