diff --git a/minitorch/autodiff.py b/minitorch/autodiff.py index 9431908..3a52170 100644 --- a/minitorch/autodiff.py +++ b/minitorch/autodiff.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Iterable, List, Tuple # noqa: F401 +from typing import Any, Iterable, Tuple from typing_extensions import Protocol @@ -22,8 +22,10 @@ def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6) Returns: An approximation of $f'_i(x_0, \ldots, x_{n-1})$ """ - # TODO: Implement for Task 1.1. - raise NotImplementedError("Need to implement for Task 1.1") + return ( + f(*[v + epsilon if i == arg else v for i, v in enumerate(vals)]) + - f(*[v - epsilon if i == arg else v for i, v in enumerate(vals)]) + ) / (2 * epsilon) variable_count = 1