Skip to content

Commit

Permalink
chore: add pre-commit configuration file, run ruff-isort (#634)
Browse files Browse the repository at this point in the history
Continuing from #633, this PR adds support for the `pre-commit` via its
configuration file, which works with https://pre-commit.ci.
- I have suppressed all failing rules to get `pre-commit` passing in
68e3dda, so that we can fix them all
later.
- At most, the style changes in this PR are the imports that have been
tidied up and sorted using the `ruff check . --select I --fix && ruff
format .` command.
- The pre-commit config configuration file was added

Additionally, to automatically enable `pre-commit` as a Git hook
locally, type:

```bash
pip install pre-commit
pre-commit install
```

and subsequently, it would not allow creating commits that fail the
style rules we have set in `pyproject.toml` or those in
`.pre-commit-config.yaml`.
  • Loading branch information
agriyakhetarpal authored Aug 30, 2024
1 parent 95ae714 commit f35fcca
Show file tree
Hide file tree
Showing 97 changed files with 332 additions and 277 deletions.
33 changes: 33 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
autofix_commit_msg: "style: pre-commit fixes"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-yaml
exclude: conda_recipe/conda.yaml
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.2"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-check-blanket-type-ignore
exclude: ^src/vector/backends/_numba_object.py$
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal
29 changes: 15 additions & 14 deletions autograd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from autograd.core import primitive_with_deprecation_warnings as primitive

from .builtins import dict, isinstance, list, tuple, type
from .differential_operators import (
make_vjp,
grad,
multigrad_dict,
checkpoint,
deriv,
elementwise_grad,
value_and_grad,
grad,
grad_and_aux,
grad_named,
hessian,
hessian_tensor_product,
hessian_vector_product,
hessian,
holomorphic_grad,
jacobian,
tensor_jacobian_product,
vector_jacobian_product,
grad_named,
checkpoint,
make_ggnvp,
make_hvp,
make_jvp,
make_ggnvp,
deriv,
holomorphic_grad,
make_vjp,
multigrad_dict,
tensor_jacobian_product,
value_and_grad,
vector_jacobian_product,
)
from .builtins import isinstance, type, tuple, list, dict
from autograd.core import primitive_with_deprecation_warnings as primitive
14 changes: 7 additions & 7 deletions autograd/builtins.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from .util import subvals
from .extend import (
Box,
primitive,
notrace_primitive,
VSpace,
vspace,
SparseObject,
defvjp,
defvjp_argnum,
VSpace,
defjvp,
defjvp_argnum,
defvjp,
defvjp_argnum,
notrace_primitive,
primitive,
vspace,
)
from .util import subvals

isinstance_ = isinstance
isinstance = notrace_primitive(isinstance)
Expand Down
9 changes: 4 additions & 5 deletions autograd/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from itertools import count
from functools import reduce
from .tracer import trace, primitive, toposort, Node, Box, isbox, getval
from itertools import count

from .tracer import Box, Node, getval, isbox, primitive, toposort, trace
from .util import func, subval

# -------------------- reverse mode --------------------
Expand Down Expand Up @@ -40,9 +41,7 @@ def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
vjpmaker = primitive_vjps[fun]
except KeyError:
fun_name = getattr(fun, "__name__", fun)
raise NotImplementedError(
f"VJP of {fun_name} wrt argnums {parent_argnums} not defined"
)
raise NotImplementedError(f"VJP of {fun_name} wrt argnums {parent_argnums} not defined")
self.vjp = vjpmaker(parent_argnums, value, args, kwargs)

def initialize_root(self):
Expand Down
12 changes: 6 additions & 6 deletions autograd/differential_operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Convenience functions built on top of `make_vjp`."""

from functools import partial
from collections import OrderedDict

try:
Expand All @@ -9,13 +8,14 @@
from inspect import getargspec as _getargspec # Python 2
import warnings

from .wrap_util import unary_to_nary
from .builtins import tuple as atuple
from .core import make_vjp as _make_vjp, make_jvp as _make_jvp
from .extend import primitive, defvjp_argnum, vspace

import autograd.numpy as np

from .builtins import tuple as atuple
from .core import make_jvp as _make_jvp
from .core import make_vjp as _make_vjp
from .extend import defvjp_argnum, primitive, vspace
from .wrap_util import unary_to_nary

make_vjp = unary_to_nary(_make_vjp)
make_jvp = unary_to_nary(_make_jvp)

Expand Down
20 changes: 10 additions & 10 deletions autograd/extend.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Exposes API for extending autograd
from .tracer import Box, primitive, register_notrace, notrace_primitive
from .core import (
JVPNode,
SparseObject,
VSpace,
vspace,
VJPNode,
JVPNode,
defvjp_argnums,
defvjp_argnum,
defvjp,
defjvp_argnums,
defjvp_argnum,
defjvp,
VSpace,
def_linear,
defjvp,
defjvp_argnum,
defjvp_argnums,
defvjp,
defvjp_argnum,
defvjp_argnums,
vspace,
)
from .tracer import Box, notrace_primitive, primitive, register_notrace
2 changes: 1 addition & 1 deletion autograd/misc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .tracers import const_graph
from .flatten import flatten
from .tracers import const_graph
4 changes: 2 additions & 2 deletions autograd/misc/fixed_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from autograd.extend import primitive, defvjp, vspace
from autograd.builtins import tuple
from autograd import make_vjp
from autograd.builtins import tuple
from autograd.extend import defvjp, primitive, vspace


@primitive
Expand Down
2 changes: 1 addition & 1 deletion autograd/misc/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
arrays. The main purpose is to make examples and optimizers simpler.
"""

import autograd.numpy as np
from autograd import make_vjp
from autograd.builtins import type
import autograd.numpy as np


def flatten(value):
Expand Down
1 change: 0 additions & 1 deletion autograd/misc/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
These routines can optimize functions whose inputs are structured
objects, such as dicts of numpy arrays."""


import autograd.numpy as np
from autograd.misc import flatten
from autograd.wrap_util import wraps
Expand Down
7 changes: 4 additions & 3 deletions autograd/misc/tracers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import partial
from itertools import repeat
from autograd.wrap_util import wraps

from autograd.tracer import Node, trace
from autograd.util import subvals, toposort
from autograd.tracer import trace, Node
from functools import partial
from autograd.wrap_util import wraps


class ConstGraphNode(Node):
Expand Down
8 changes: 1 addition & 7 deletions autograd/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random
from .numpy_wrapper import *
from . import numpy_boxes
from . import numpy_vspaces
from . import numpy_vjps
from . import numpy_jvps
from . import linalg
from . import fft
from . import random
8 changes: 5 additions & 3 deletions autograd/numpy/fft.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy.fft as ffto
from .numpy_wrapper import wrap_namespace
from .numpy_vjps import match_complex

from autograd.extend import defvjp, primitive, vspace

from . import numpy_wrapper as anp
from autograd.extend import primitive, defvjp, vspace
from .numpy_vjps import match_complex
from .numpy_wrapper import wrap_namespace

wrap_namespace(ffto.__dict__, globals())

Expand Down
7 changes: 5 additions & 2 deletions autograd/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from functools import partial

import numpy.linalg as npla
from .numpy_wrapper import wrap_namespace

from autograd.extend import defjvp, defvjp

from . import numpy_wrapper as anp
from autograd.extend import defvjp, defjvp
from .numpy_wrapper import wrap_namespace

wrap_namespace(npla.__dict__, globals())

Expand Down
4 changes: 3 additions & 1 deletion autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from autograd.extend import Box, primitive

from autograd.builtins import SequenceBox
from autograd.extend import Box, primitive

from . import numpy_wrapper as anp

Box.__array_priority__ = 90.0
Expand Down
16 changes: 9 additions & 7 deletions autograd/numpy/numpy_jvps.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import numpy as onp

from autograd.extend import JVPNode, def_linear, defjvp, defjvp_argnum, register_notrace, vspace

from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from .numpy_vjps import (
untake,
balanced_eq,
match_complex,
replace_zero,
dot_adjoint_0,
dot_adjoint_1,
match_complex,
nograd_functions,
replace_zero,
tensordot_adjoint_0,
tensordot_adjoint_1,
nograd_functions,
untake,
)
from autograd.extend import defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace
from ..util import func
from .numpy_boxes import ArrayBox

for fun in nograd_functions:
register_notrace(JVPNode, fun)
Expand Down
5 changes: 4 additions & 1 deletion autograd/numpy/numpy_vjps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from functools import partial

import numpy as onp

from autograd.extend import SparseObject, VJPNode, defvjp, defvjp_argnum, primitive, register_notrace, vspace

from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from autograd.extend import primitive, vspace, defvjp, defvjp_argnum, SparseObject, VJPNode, register_notrace

# ----- Non-differentiable functions -----

Expand Down
3 changes: 2 additions & 1 deletion autograd/numpy/numpy_vspaces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from autograd.extend import VSpace

from autograd.builtins import NamedTupleVSpace
from autograd.extend import VSpace


class ArrayVSpace(VSpace):
Expand Down
9 changes: 4 additions & 5 deletions autograd/numpy/numpy_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import types
import warnings
from autograd.extend import primitive, notrace_primitive

import numpy as _np

import autograd.builtins as builtins
from autograd.extend import notrace_primitive, primitive

if _np.lib.NumpyVersion(_np.__version__) >= "2.0.0":
from numpy._core.einsumfunc import _parse_einsum_input
Expand Down Expand Up @@ -75,9 +76,7 @@ def array(A, *args, **kwargs):
def wrap_if_boxes_inside(raw_array, slow_op_name=None):
if raw_array.dtype is _np.dtype("O"):
if slow_op_name:
warnings.warn(
"{} is slow for array inputs. " "np.concatenate() is faster.".format(slow_op_name)
)
warnings.warn("{} is slow for array inputs. " "np.concatenate() is faster.".format(slow_op_name))
return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
else:
return raw_array
Expand Down
1 change: 1 addition & 0 deletions autograd/numpy/random.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy.random as npr

from .numpy_wrapper import wrap_namespace

wrap_namespace(npr.__dict__, globals())
5 changes: 1 addition & 4 deletions autograd/scipy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from . import integrate
from . import signal
from . import special
from . import stats
from . import integrate, signal, special, stats

try:
from . import misc
Expand Down
4 changes: 2 additions & 2 deletions autograd/scipy/integrate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import scipy.integrate

import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple
from autograd.extend import defvjp_argnums, primitive
from autograd.misc import flatten

odeint = primitive(scipy.integrate.odeint)

Expand Down
3 changes: 2 additions & 1 deletion autograd/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from functools import partial

import scipy.linalg

import autograd.numpy as anp
from autograd.extend import defjvp, defjvp_argnums, defvjp, defvjp_argnums
from autograd.numpy.numpy_wrapper import wrap_namespace
from autograd.extend import defvjp, defvjp_argnums, defjvp, defjvp_argnums

wrap_namespace(scipy.linalg.__dict__, globals()) # populates module namespace

Expand Down
1 change: 1 addition & 0 deletions autograd/scipy/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import scipy.misc as osp_misc

from ..scipy import special

if hasattr(osp_misc, "logsumexp"):
Expand Down
Loading

0 comments on commit f35fcca

Please sign in to comment.