-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #31 from nanxstats/distributed
Place dataset and loss into separate modules
- Loading branch information
Showing
14 changed files
with
232 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Data | ||
|
||
::: tinytopics.data | ||
options: | ||
members: | ||
- NumpyDiskDataset | ||
- IndexTrackingDataset | ||
show_root_heading: true | ||
show_source: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,5 @@ | |
options: | ||
members: | ||
- fit_model | ||
- poisson_nmf_loss | ||
show_root_heading: true | ||
show_source: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Losses | ||
|
||
::: tinytopics.loss | ||
options: | ||
members: | ||
- poisson_nmf_loss | ||
show_root_heading: true | ||
show_source: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,16 @@ | ||
""" | ||
Topic modeling via sum-to-one constrained neural Poisson NMF. | ||
Modules: | ||
fit: Model fitting and loss calculation. | ||
models: NeuralPoissonNMF model definition. | ||
plot: Functions for plotting loss curves, document-topic distributions, and top terms. | ||
colors: Color palettes. | ||
utils: Utility functions for data generation, topic alignment, and document sorting. | ||
""" | ||
|
||
from .fit import fit_model | ||
from .models import NeuralPoissonNMF | ||
from .fit import fit_model, poisson_nmf_loss | ||
from .loss import poisson_nmf_loss | ||
from .data import NumpyDiskDataset | ||
from .utils import ( | ||
set_random_seed, | ||
generate_synthetic_data, | ||
align_topics, | ||
sort_documents, | ||
NumpyDiskDataset, | ||
) | ||
from .colors import pal_tinytopics, scale_color_tinytopics | ||
from .plot import plot_loss, plot_structure, plot_top_terms |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from collections.abc import Sequence | ||
from pathlib import Path | ||
|
||
import torch | ||
import numpy as np | ||
from torch import Tensor | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class IndexTrackingDataset(Dataset): | ||
"""Dataset wrapper that tracks indices through shuffling""" | ||
|
||
def __init__(self, dataset: Dataset | Tensor) -> None: | ||
self.dataset = dataset | ||
self.shape: tuple[int, int] = ( | ||
dataset.shape | ||
if hasattr(dataset, "shape") | ||
else (len(dataset), dataset[0].shape[0]) | ||
) | ||
self.is_tensor: bool = isinstance(dataset, torch.Tensor) | ||
|
||
def __len__(self) -> int: | ||
return len(self.dataset) | ||
|
||
def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]: | ||
return self.dataset[idx], torch.tensor(idx) | ||
|
||
|
||
class NumpyDiskDataset(Dataset): | ||
""" | ||
A PyTorch Dataset class for loading document-term matrices from disk. | ||
The dataset can be initialized with either a path to a `.npy` file or | ||
a NumPy array. When a file path is provided, the data is accessed | ||
lazily using memory mapping, which is useful for handling large datasets | ||
that do not fit entirely in (CPU) memory. | ||
""" | ||
|
||
def __init__( | ||
self, data: str | Path | np.ndarray, indices: Sequence[int] | None = None | ||
) -> None: | ||
""" | ||
Args: | ||
data: Either path to `.npy` file (str or Path) or numpy array. | ||
indices: Optional sequence of indices to use as valid indices. | ||
""" | ||
if isinstance(data, (str, Path)): | ||
data_path = Path(data) | ||
if not data_path.exists(): | ||
raise FileNotFoundError(f"Data file not found: {data_path}") | ||
# Get shape without loading full array | ||
self.shape: tuple[int, int] = tuple(np.load(data_path, mmap_mode="r").shape) | ||
self.data_path: Path = data_path | ||
self.mmap_data: np.ndarray | None = None | ||
else: | ||
self.shape: tuple[int, int] = data.shape | ||
self.data_path: None = None | ||
self.data: np.ndarray = data | ||
|
||
self.indices: Sequence[int] = indices or range(self.shape[0]) | ||
|
||
def __len__(self) -> int: | ||
return len(self.indices) | ||
|
||
def __getitem__(self, idx: int) -> torch.Tensor: | ||
real_idx = self.indices[idx] | ||
|
||
if self.data_path is not None: | ||
# Load mmap data lazily | ||
if self.mmap_data is None: | ||
self.mmap_data = np.load(self.data_path, mmap_mode="r") | ||
return torch.tensor(self.mmap_data[real_idx], dtype=torch.float32) | ||
else: | ||
return torch.tensor(self.data[real_idx], dtype=torch.float32) | ||
|
||
@property | ||
def num_terms(self) -> int: | ||
"""Return vocabulary size (number of columns).""" | ||
return self.shape[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def poisson_nmf_loss(X: Tensor, X_reconstructed: Tensor) -> Tensor: | ||
""" | ||
Compute the Poisson NMF loss function (negative log-likelihood). | ||
Args: | ||
X: Original document-term matrix. | ||
X_reconstructed: Reconstructed matrix from the model. | ||
Returns: | ||
The computed Poisson NMF loss. | ||
""" | ||
epsilon: float = 1e-10 | ||
return ( | ||
X_reconstructed - X * torch.log(torch.clamp(X_reconstructed, min=epsilon)) | ||
).sum() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import pytest | ||
import torch | ||
import numpy as np | ||
|
||
from tinytopics.data import NumpyDiskDataset | ||
|
||
|
||
def test_numpy_disk_dataset_from_array(): | ||
"""Test NumpyDiskDataset with direct numpy array input.""" | ||
data = np.random.rand(10, 5).astype(np.float32) | ||
|
||
dataset = NumpyDiskDataset(data) | ||
|
||
# Test basic properties | ||
assert len(dataset) == 10 | ||
assert dataset.num_terms == 5 | ||
assert dataset.shape == (10, 5) | ||
|
||
# Test data access | ||
for i in range(len(dataset)): | ||
item = dataset[i] | ||
assert isinstance(item, torch.Tensor) | ||
assert item.shape == (5,) | ||
assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32)) | ||
|
||
|
||
def test_numpy_disk_dataset_from_file(tmp_path): | ||
"""Test NumpyDiskDataset with .npy file input.""" | ||
data = np.random.rand(10, 5).astype(np.float32) | ||
file_path = tmp_path / "test_data.npy" | ||
np.save(file_path, data) | ||
|
||
dataset = NumpyDiskDataset(file_path) | ||
|
||
# Test basic properties | ||
assert len(dataset) == 10 | ||
assert dataset.num_terms == 5 | ||
assert dataset.shape == (10, 5) | ||
|
||
# Test data access | ||
for i in range(len(dataset)): | ||
item = dataset[i] | ||
assert isinstance(item, torch.Tensor) | ||
assert item.shape == (5,) | ||
assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32)) | ||
|
||
|
||
def test_numpy_disk_dataset_with_indices(): | ||
"""Test NumpyDiskDataset with custom indices.""" | ||
data = np.random.rand(10, 5).astype(np.float32) | ||
indices = [3, 1, 4] | ||
|
||
dataset = NumpyDiskDataset(data, indices=indices) | ||
|
||
# Test basic properties | ||
assert len(dataset) == len(indices) | ||
assert dataset.num_terms == 5 | ||
assert dataset.shape == (10, 5) | ||
|
||
# Test data access | ||
for i, orig_idx in enumerate(indices): | ||
item = dataset[i] | ||
assert isinstance(item, torch.Tensor) | ||
assert item.shape == (5,) | ||
assert torch.allclose(item, torch.tensor(data[orig_idx], dtype=torch.float32)) | ||
|
||
|
||
def test_numpy_disk_dataset_file_not_found(): | ||
"""Test NumpyDiskDataset with non-existent file.""" | ||
with pytest.raises(FileNotFoundError): | ||
NumpyDiskDataset("non_existent_file.npy") | ||
|
||
|
||
def test_numpy_disk_dataset_memory_efficiency(tmp_path): | ||
"""Test that NumpyDiskDataset uses memory mapping efficiently.""" | ||
shape = (1000, 500) # 500K elements | ||
data = np.random.rand(*shape).astype(np.float32) | ||
file_path = tmp_path / "large_data.npy" | ||
np.save(file_path, data) | ||
|
||
dataset = NumpyDiskDataset(file_path) | ||
|
||
# Access data in random order | ||
indices = np.random.permutation(shape[0])[:100] # Sample 100 random rows | ||
for idx in indices: | ||
item = dataset[idx] | ||
assert torch.allclose(item, torch.tensor(data[idx], dtype=torch.float32)) | ||
|
||
# Memory mapping should be initialized only after first access | ||
assert dataset.mmap_data is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.