Skip to content

Commit

Permalink
Fix Mypy Linter Errors in plotting.py (#405)
Browse files Browse the repository at this point in the history
Summary:
This PR resolves mypy linter issues in `plotting.py` by adding `None` checks for `strat.model`, providing a missing type hint for the `locs` variable, and updating `matplotlib.markers` to use a string digit instead of an integer.

This was done because adding more strict type hints in other parts of the code revealed these errors, which need to be fixed before proceeding with #403 .

Pull Request resolved: #405

Reviewed By: crasanders

Differential Revision: D64415474

Pulled By: JasonKChow

fbshipit-source-id: 1bbb3cc28d7c15b193404f2ca6407a6996d4d062
  • Loading branch information
yalsaffar authored and facebook-github-bot committed Oct 16, 2024
1 parent 45d8e2d commit 8f38b4a
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,12 @@ def _plot_strat_1d(
x, y = strat.x, strat.y
assert x is not None and y is not None, "No data to plot!"

grid = strat.model.dim_grid(gridsize=gridsize)
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
if strat.model is not None:
grid = strat.model.dim_grid(gridsize=gridsize)
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
else:
raise RuntimeError("Cannot plot without a model!")

ax.plot(np.squeeze(grid), phimean)
if cred_level is not None:
Expand Down Expand Up @@ -215,14 +218,14 @@ def _plot_strat_1d(
ax.scatter(
x[y == 0, 0],
np.zeros_like(x[y == 0, 0]),
marker=3,
marker="3",
color="r",
label=no_label,
)
ax.scatter(
x[y == 1, 0],
np.zeros_like(x[y == 1, 0]),
marker=3,
marker="3",
color="b",
label=yes_label,
)
Expand Down Expand Up @@ -253,11 +256,14 @@ def _plot_strat_2d(
assert x is not None and y is not None, "No data to plot!"

# make sure the model is fit well if we've been limiting fit time
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)
if strat.model is not None:
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

grid = strat.model.dim_grid(gridsize=gridsize)
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
grid = strat.model.dim_grid(gridsize=gridsize)
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
else:
raise RuntimeError("Cannot plot without a model!")

extent = np.r_[strat.lb[0], strat.ub[0], strat.lb[1], strat.ub[1]]
colormap = ax.imshow(
Expand All @@ -277,7 +283,7 @@ def _plot_strat_2d(

# hacky relabel to be in logspace
if logx:
locs = np.arange(strat.lb[0], strat.ub[0])
locs: np.ndarray = np.arange(strat.lb[0], strat.ub[0])
ax.set_xticks(ticks=locs)
ax.set_xticklabels(2.0**locs)

Expand Down

0 comments on commit 8f38b4a

Please sign in to comment.