Skip to content

Commit

Permalink
Merge pull request #5 from nanxstats/colormap
Browse files Browse the repository at this point in the history
Add a custom color scale that supports an arbitrary number of topics
  • Loading branch information
nanxstats authored Oct 26, 2024
2 parents 622b19a + f78ec5c commit b25e745
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 18 deletions.
9 changes: 9 additions & 0 deletions docs/reference/colors.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Colors

::: tinytopics.colors
options:
members:
- pal_tinytopics
- scale_color_tinytopics
show_root_heading: true
show_source: false
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ nav:
- Fit: reference/fit.md
- Models: reference/models.md
- Plot: reference/plot.md
- Colors: reference/colors.md
- Utilities: reference/utils.md
- Changelog: changelog.md

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies = [
"numpy>=2.0.0",
"scipy>=1.13.0",
"matplotlib>=3.8.4",
"scikit-image>=0.22.0",
"tqdm>=4.65.0",
]
readme = "README.md"
Expand Down
17 changes: 17 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ idna==3.10
# via httpx
# via jsonschema
# via requests
imageio==2.36.0
# via scikit-image
ipykernel==6.29.5
# via jupyter
# via jupyter-console
Expand Down Expand Up @@ -163,6 +165,8 @@ jupyterlab-widgets==3.0.13
# via ipywidgets
kiwisolver==1.4.7
# via matplotlib
lazy-loader==0.4
# via scikit-image
markdown==3.7
# via mkdocs
# via mkdocs-autorefs
Expand Down Expand Up @@ -214,6 +218,7 @@ nbformat==5.10.4
nest-asyncio==1.6.0
# via ipykernel
networkx==3.4.2
# via scikit-image
# via torch
notebook==7.2.2
# via jupyter
Expand All @@ -222,8 +227,11 @@ notebook-shim==0.2.4
# via notebook
numpy==2.1.2
# via contourpy
# via imageio
# via matplotlib
# via scikit-image
# via scipy
# via tifffile
# via tinytopics
overrides==7.7.0
# via jupyter-server
Expand All @@ -232,9 +240,11 @@ packaging==24.1
# via jupyter-server
# via jupyterlab
# via jupyterlab-server
# via lazy-loader
# via matplotlib
# via mkdocs
# via nbconvert
# via scikit-image
paginate==0.5.7
# via mkdocs-material
pandocfilters==1.5.1
Expand All @@ -246,7 +256,9 @@ pathspec==0.12.1
pexpect==4.9.0
# via ipython
pillow==11.0.0
# via imageio
# via matplotlib
# via scikit-image
platformdirs==4.3.6
# via jupyter-core
# via mkdocs-get-deps
Expand Down Expand Up @@ -314,7 +326,10 @@ rpds-py==0.20.0
# via jsonschema
# via referencing
ruff==0.7.0
scikit-image==0.24.0
# via tinytopics
scipy==1.14.1
# via scikit-image
# via tinytopics
send2trash==1.8.3
# via jupyter-server
Expand All @@ -338,6 +353,8 @@ sympy==1.13.1
terminado==0.18.1
# via jupyter-server
# via jupyter-server-terminals
tifffile==2024.9.20
# via scikit-image
tinycss2==1.3.0
# via nbconvert
torch==2.5.0
Expand Down
17 changes: 17 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,56 @@ fonttools==4.54.1
# via matplotlib
fsspec==2024.10.0
# via torch
imageio==2.36.0
# via scikit-image
jinja2==3.1.4
# via torch
kiwisolver==1.4.7
# via matplotlib
lazy-loader==0.4
# via scikit-image
markupsafe==3.0.2
# via jinja2
matplotlib==3.9.2
# via tinytopics
mpmath==1.3.0
# via sympy
networkx==3.4.2
# via scikit-image
# via torch
numpy==2.1.2
# via contourpy
# via imageio
# via matplotlib
# via scikit-image
# via scipy
# via tifffile
# via tinytopics
packaging==24.1
# via lazy-loader
# via matplotlib
# via scikit-image
pillow==11.0.0
# via imageio
# via matplotlib
# via scikit-image
pyparsing==3.2.0
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
scikit-image==0.24.0
# via tinytopics
scipy==1.14.1
# via scikit-image
# via tinytopics
setuptools==75.2.0
# via torch
six==1.16.0
# via python-dateutil
sympy==1.13.1
# via torch
tifffile==2024.9.20
# via scikit-image
torch==2.5.0
# via tinytopics
tqdm==4.66.5
Expand Down
92 changes: 92 additions & 0 deletions src/tinytopics/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
from matplotlib import colors
from skimage import color
from scipy.interpolate import make_interp_spline


def pal_tinytopics(format="hex"):
    """
    Return the tinytopics 10 color palette.

    A rearranged version of the original Observable 10 palette.
    Reordered to align with the color ordering of the D3 Category 10 palette,
    also known as `matplotlib.cm.tab10`.
    The rearrangement aims to improve perceptual familiarity and color harmony,
    especially when used in a context where color interpolation is needed.

    Args:
        format (str, optional):
            Returned color format. Options are:
                `hex`: Hex strings (default).
                `rgb`: Array of RGB values.
                `lab`: Array of CIELAB values.

    Returns:
        (list or np.ndarray):
            - If `format='hex'`, returns a list of hex color strings.
            - If `format='rgb'`, returns an Nx3 numpy array of RGB values.
            - If `format='lab'`, returns an Nx3 numpy array of CIELAB values.

    Raises:
        ValueError: If `format` is not one of `'hex'`, `'rgb'`, or `'lab'`.
    """
    tinytopics_10_colors_hex = [
        "#4269D0",  # Blue
        "#EFB118",  # Orange
        "#3CA951",  # Green
        "#FF725C",  # Red
        "#A463F2",  # Purple
        "#9C6B4E",  # Brown
        "#FF8AB7",  # Pink
        "#9498A0",  # Gray
        "#6CC5B0",  # Cyan
        "#97BBF5",  # Light Blue
    ]

    if format == "hex":
        return tinytopics_10_colors_hex
    if format not in ("rgb", "lab"):
        raise ValueError("Format must be 'hex', 'rgb', or 'lab'.")

    # Both remaining formats start from the RGB array, so build it once.
    # The loop variable is deliberately not called `color`, to avoid
    # shadowing the `skimage.color` module used below.
    rgb_colors = np.array(
        [colors.to_rgb(hex_code) for hex_code in tinytopics_10_colors_hex]
    )
    if format == "rgb":
        return rgb_colors

    # Convert RGB to CIELAB: reshape to a 1 x N x 3 array for the conversion,
    # then flatten the result back to N x 3.
    return color.rgb2lab(rgb_colors.reshape(1, -1, 3)).reshape(-1, 3)


def scale_color_tinytopics(n):
    """
    A tinytopics 10 color scale.

    If `n` does not exceed the size of the base palette, the first `n` base
    colors are returned as-is. If more colors are required, an interpolated
    palette is generated from the base palette in the CIELAB color space
    using cubic B-splines.

    Args:
        n (int): The number of colors needed.

    Returns:
        (matplotlib.colors.ListedColormap): A colormap with n colors,
            possibly interpolated from the base colors.

    Raises:
        ValueError: If `n` is negative.
    """
    # A negative n would otherwise slice from the end of the palette and
    # silently return the wrong number of colors.
    if n < 0:
        raise ValueError("n must be a non-negative integer.")

    base_rgb_colors = pal_tinytopics(format="rgb")
    n_base = len(base_rgb_colors)

    # If interpolation is NOT needed, return the first n colors directly
    if n <= n_base:
        return colors.ListedColormap(base_rgb_colors[:n])

    # Interpolation is needed: only now pay for the CIELAB conversion.
    # Interpolating in CIELAB is intended to give perceptually smoother
    # transitions than interpolating in RGB.
    base_lab_colors = pal_tinytopics(format="lab")

    # Positions of the base colors, evenly spaced on [0, 1]
    x = np.linspace(0, 1, n_base)
    # Cubic B-spline interpolator through the base colors in CIELAB space
    bspline = make_interp_spline(x, base_lab_colors, k=3)
    # Evenly spaced positions for the n requested colors
    x_new = np.linspace(0, 1, n)
    interpolated_lab = bspline(x_new)

    # Convert interpolated LAB colors back to RGB (via a 1 x n x 3 array)
    interpolated_rgb = color.lab2rgb(
        interpolated_lab.reshape(1, -1, 3)
    ).reshape(-1, 3)

    return colors.ListedColormap(interpolated_rgb)
36 changes: 18 additions & 18 deletions src/tinytopics/plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import numpy as np
import matplotlib.pyplot as plt
from .colors import scale_color_tinytopics


def plot_loss(losses, figsize=(10, 8), dpi=300, title="Loss curve", output_file=None):
def plot_loss(
losses,
figsize=(10, 8),
dpi=300,
title="Loss curve",
color_palette=None,
output_file=None,
):
"""
Plot the loss curve over training epochs.
Expand All @@ -11,10 +19,14 @@ def plot_loss(losses, figsize=(10, 8), dpi=300, title="Loss curve", output_file=
figsize (tuple, optional): Plot size. Default is (10, 8).
dpi (int, optional): Plot resolution. Default is 300.
title (str, optional): Plot title. Default is "Loss curve".
color_palette (list or matplotlib colormap, optional): Custom color palette.
output_file (str, optional): File path to save the plot. If None, displays the plot.
"""
if color_palette is None:
color_palette = scale_color_tinytopics(1)

plt.figure(figsize=figsize, dpi=dpi)
plt.plot(losses)
plt.plot(losses, color=color_palette(0))
plt.title(title)
plt.xlabel("Epochs")
plt.ylabel("Loss")
Expand Down Expand Up @@ -53,22 +65,16 @@ def plot_structure(
ind = np.arange(n_documents) # Document indices
cumulative = np.zeros(n_documents)

# Color palette
if color_palette is None:
colors = plt.cm.tab20(np.linspace(0, 1, n_topics))
else:
if isinstance(color_palette, list):
colors = color_palette
else:
colors = color_palette(np.linspace(0, 1, n_topics))
color_palette = scale_color_tinytopics(n_topics)

plt.figure(figsize=figsize, dpi=dpi)
for k in range(n_topics):
plt.bar(
ind,
L_matrix[:, k],
bottom=cumulative,
color=colors[k % len(colors)],
color=color_palette(k),
width=1.0,
)
cumulative += L_matrix[:, k]
Expand Down Expand Up @@ -123,14 +129,8 @@ def plot_top_terms(
else:
top_terms_labels = top_terms_indices.astype(str)

# Color palette
if color_palette is None:
colors = plt.cm.tab20(np.linspace(0, 1, n_topics))
else:
if isinstance(color_palette, list):
colors = color_palette
else:
colors = color_palette(np.linspace(0, 1, n_topics))
color_palette = scale_color_tinytopics(n_topics)

# Grid layout
if nrows is None and ncols is None:
Expand All @@ -154,7 +154,7 @@ def plot_top_terms(

# Place highest probability terms at the top
y_pos = np.arange(n_top_terms)[::-1]
ax.barh(y_pos, probs, color=colors[i % len(colors)])
ax.barh(y_pos, probs, color=color_palette(i))
ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.set_xlabel("Probability")
Expand Down

0 comments on commit b25e745

Please sign in to comment.