Skip to content

Commit

Permalink
Merge pull request #43 from yb6599/master
Browse files Browse the repository at this point in the history
Derivative.d now handles negative axis arguments
  • Loading branch information
andgoldschmidt authored Jun 19, 2024
2 parents 9474e7f + 819f710 commit b5c9add
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
5 changes: 4 additions & 1 deletion derivative/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra
return dX.flatten()
else:
# order of operations coupled with _align_axes
extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis)
orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args
extra_dims = tuple(
length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis
)
moved_shape = (orig_shape[axis],) + extra_dims
dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis)
return dX
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python = "^3.9"
numpy = "^1.18.3"
scipy = "^1.4.1"
scikit-learn = "^1"
importlib-metadata = "^7.1.0"

# docs
sphinx = {version = "^5", optional = true}
Expand Down
13 changes: 12 additions & 1 deletion tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,15 @@ def test_hyperparam_entrypoint():
func = utils._load_hyperparam_func("kalman.default")
expected = 1
result = func(None, None)
assert result == expected
assert result == expected


def test_negative_axis():
t = np.arange(3)
x = np.random.random(size=(2, 3, 2))
x[1, :, 1] = 1
axis = -2
expected = np.zeros(3)
dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1)
assert x.shape == dx.shape
np.testing.assert_array_almost_equal(dx[1, :, 1], expected)

0 comments on commit b5c9add

Please sign in to comment.