Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI failures #155

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion schema.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]) -> T:

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
AxesPoints = dict[Axis, npt.NDArray[np.floating[Any]]]
AxesPoints = dict[Axis, npt.NDArray[np.float64]]


class Frames(Generic[Axis]):
Expand Down Expand Up @@ -337,9 +337,9 @@ def concat(self, other: Frames[Axis], gap: bool = False) -> Frames[Axis]:
{'x': array([1, 2, 3, 4, 5, 6]), 'y': array([6, 5, 4, 3, 2, 1])}

"""
assert set(self.axes()) == set(
other.axes()
), f"axes {self.axes()} != {other.axes()}"
assert set(self.axes()) == set(other.axes()), (
f"axes {self.axes()} != {other.axes()}"
)

def concat_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# Concat each array in midpoints, lower, upper. E.g.
Expand Down
36 changes: 17 additions & 19 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
__all__ = ["plot_spec"]


def _plot_arrays(
axes: Axes, arrays: list[npt.NDArray[np.floating[Any]]], **kwargs: Any
):
def _plot_arrays(axes: Axes, arrays: list[npt.NDArray[np.float64]], **kwargs: Any):
if len(arrays) > 2:
axes.plot3D(arrays[2], arrays[1], arrays[0], **kwargs) # type: ignore
elif len(arrays) == 2:
Expand All @@ -34,35 +32,35 @@
class Arrow3D(patches.FancyArrowPatch):
def __init__(
self,
xs: npt.NDArray[np.floating[Any]],
ys: npt.NDArray[np.floating[Any]],
zs: npt.NDArray[np.floating[Any]],
xs: npt.NDArray[np.float64],
ys: npt.NDArray[np.float64],
zs: npt.NDArray[np.float64],
*args: Any,
**kwargs: Any,
):
super().__init__((0, 0), (0, 0), *args, **kwargs) # type: ignore
self._verts3d = xs, ys, zs

# Added here because of https://github.com/matplotlib/matplotlib/issues/21688
def do_3d_projection(self, renderer: Any = None):
def do_3d_projection(self, renderer: Any = None): # type: ignore
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) # type: ignore

Check warning on line 48 in src/scanspec/plot.py

View check run for this annotation

Codecov / codecov/patch

src/scanspec/plot.py#L48

Added line #L48 was not covered by tests

return np.min(zs)
return np.min(zs) # type: ignore

Check warning on line 50 in src/scanspec/plot.py

View check run for this annotation

Codecov / codecov/patch

src/scanspec/plot.py#L50

Added line #L50 was not covered by tests

@property
def verts3d(
self,
) -> tuple[
npt.NDArray[np.floating[Any]],
npt.NDArray[np.floating[Any]],
npt.NDArray[np.floating[Any]],
npt.NDArray[np.float64],
npt.NDArray[np.float64],
npt.NDArray[np.float64],
]:
return self._verts3d


def _plot_arrow(axes: Axes, arrays: list[npt.NDArray[np.floating[Any]]]):
def _plot_arrow(axes: Axes, arrays: list[npt.NDArray[np.float64]]):
if len(arrays) == 1:
arrays = [np.array([0, 0])] + arrays
if len(arrays) == 2:
Expand All @@ -83,9 +81,9 @@
def _plot_spline(
axes: Axes,
ranges: list[float],
arrays: list[npt.NDArray[np.floating[Any]]],
arrays: list[npt.NDArray[np.float64]],
index_colours: dict[int, str],
) -> Iterable[list[npt.NDArray[np.floating[Any]]]]:
) -> Iterable[list[npt.NDArray[np.float64]]]:
scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)]
# Define curves parametrically
t = np.zeros(len(arrays[0]))
Expand All @@ -107,7 +105,7 @@
start_value: float = t[start]
stop_value: float = t[stop]
tnew = np.linspace(start_value, stop_value, num=1001)
spline: npt.NDArray[np.floating[Any]] = interpolate.splev(tnew, tck) # type: ignore
spline: npt.NDArray[np.float64] = interpolate.splev(tnew, tck) # type: ignore
# Scale the splines back to the original scaling
unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)]
_plot_arrays(axes, unscaled_splines, color=index_colours[start])
Expand Down Expand Up @@ -190,7 +188,7 @@
plt_axes.add_patch(patches.Polygon(xy_verts, fill=False))

# Plot the splines
tail: dict[str, npt.NDArray[np.floating[Any]] | None] = {a: None for a in axes}
tail: dict[str, npt.NDArray[np.float64] | None] = {a: None for a in axes}
ranges = [max(float(np.max(v) - np.min(v)), 0.0001) for v in dim.midpoints.values()]
seg_col = cycle(colors.TABLEAU_COLORS)
last_index = 0
Expand All @@ -200,8 +198,8 @@
gap_indices = list(np.nonzero(dim.gap[1:])[0] + 1)
for index in gap_indices + [len(dim)]:
num_points = index - last_index
arrays: list[npt.NDArray[np.floating[Any]]] = []
turnaround: list[npt.NDArray[np.floating[Any]]] = []
arrays: list[npt.NDArray[np.float64]] = []
turnaround: list[npt.NDArray[np.float64]] = []
for a in axes:
# Add the midpoints and the lower and upper bounds
arr = np.empty(num_points * 2 + 1)
Expand Down
16 changes: 8 additions & 8 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ def calculate( # noqa: D102
) -> list[Frames[Axis]]:
frames_left = self.left.calculate(bounds, nested)
frames_right = self.right.calculate(bounds, nested)
assert len(frames_left) >= len(
frames_right
), f"Zip requires len({self.left}) >= len({self.right})"
assert len(frames_left) >= len(frames_right), (
f"Zip requires len({self.left}) >= len({self.right})"
)

# Pad and expand the right to be the same size as left. Special case, if
# only one Frames object with size 1, expand to the right size
Expand All @@ -262,9 +262,9 @@ def calculate( # noqa: D102
combined = left
else:
combined = left.zip(right)
assert isinstance(
combined, Frames
), f"Padding went wrong {frames_left} {padded_right}"
assert isinstance(combined, Frames), (
f"Padding went wrong {frames_left} {padded_right}"
)
frames.append(combined)
return frames

Expand Down Expand Up @@ -456,11 +456,11 @@ def _dimensions_from_indexes(
bounds: bool,
) -> list[Frames[Axis]]:
# Calc num midpoints (fences) from 0.5 .. num - 0.5
midpoints_calc = func(np.linspace(0.5, num - 0.5, num))
midpoints_calc = func(np.linspace(0.5, num - 0.5, num, dtype=np.float64))
tpoliaw marked this conversation as resolved.
Show resolved Hide resolved
midpoints = {a: midpoints_calc[a] for a in axes}
if bounds:
# Calc num + 1 bounds (posts) from 0 .. num
bounds_calc = func(np.linspace(0, num, num + 1))
bounds_calc = func(np.linspace(0, num, num + 1, dtype=np.float64))
lower = {a: bounds_calc[a][:-1] for a in axes}
upper = {a: bounds_calc[a][1:] for a in axes}
# Points must have no gap as upper[a][i] == lower[a][i+1]
Expand Down
Loading