Skip to content

Commit

Permalink
Implement forward pass scalar functions:
Browse files Browse the repository at this point in the history
  * add
  * lt
  * gt
  * eq
  * sub
  * neg
  * log
  * exp
  * sigmoid
  * relu
  • Loading branch information
dantp-ai committed Mar 22, 2024
1 parent b35d173 commit 5810e9a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 45 deletions.
46 changes: 18 additions & 28 deletions minitorch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import numpy as np

from .autodiff import Context, Variable, backpropagate, central_difference
from .scalar_functions import EQ # noqa: F401
from .scalar_functions import LT # noqa: F401
from .scalar_functions import Add # noqa: F401
from .scalar_functions import Exp # noqa: F401
from .scalar_functions import Log # noqa: F401
from .scalar_functions import Neg # noqa: F401
from .scalar_functions import ReLU # noqa: F401
from .scalar_functions import Sigmoid # noqa: F401
from .scalar_functions import EQ
from .scalar_functions import LT
from .scalar_functions import Add
from .scalar_functions import Exp
from .scalar_functions import Log
from .scalar_functions import Neg
from .scalar_functions import ReLU
from .scalar_functions import Sigmoid
from .scalar_functions import (
Inv,
Mul,
Expand Down Expand Up @@ -92,31 +92,25 @@ def __rtruediv__(self, b: ScalarLike) -> Scalar:
return Mul.apply(b, Inv.apply(self))

def __add__(self, b: ScalarLike) -> Scalar:
    """Return ``self + b`` as a new Scalar, recording the op for backprop."""
    return Add.apply(self, b)

def __bool__(self) -> bool:
    """A Scalar is truthy exactly when its underlying data value is."""
    return True if self.data else False

def __lt__(self, b: ScalarLike) -> Scalar:
    """Return a Scalar that is 1.0 when ``self < b`` else 0.0."""
    return LT.apply(self, b)

def __gt__(self, b: ScalarLike) -> Scalar:
    """Return a Scalar that is 1.0 when ``self > b`` else 0.0.

    Implemented by flipping the operands of less-than: ``b < self``.
    """
    return LT.apply(b, self)

def __eq__(self, b: ScalarLike) -> Scalar:  # type: ignore[override]
    """Return a Scalar that is 1.0 when ``self == b`` else 0.0.

    Intentionally overrides object equality to return a Scalar, hence the
    ``type: ignore[override]``.
    """
    return EQ.apply(self, b)

def __sub__(self, b: ScalarLike) -> Scalar:
    """Return ``self - b``, implemented as ``self + (-b)``."""
    return Add.apply(self, -b)

def __neg__(self) -> Scalar:
    """Return ``-self`` as a new Scalar.

    Passes ``self`` (not ``self.data``) to ``apply``, matching how
    ``__add__`` and ``__rtruediv__`` pass Scalars: unwrapping to the raw
    float would create a detached leaf and cut this variable out of the
    backprop graph.
    """
    return Neg.apply(self)

def __radd__(self, b: ScalarLike) -> Scalar:
    """Right-hand addition: ``b + self`` delegates to ``__add__``."""
    return self.__add__(b)
Expand All @@ -125,20 +119,16 @@ def __rmul__(self, b: ScalarLike) -> Scalar:
return self * b

def log(self) -> Scalar:
    """Return a new Scalar holding ``log(self)``.

    Passes ``self`` rather than ``self.data`` so the operation is recorded
    on this variable's history (consistent with ``__add__``); passing the
    raw float would detach the result from the autodiff graph.
    """
    return Log.apply(self)

def exp(self) -> Scalar:
    """Return a new Scalar holding ``exp(self)``.

    Passes ``self`` rather than ``self.data`` so the result stays attached
    to this variable's history for backprop.
    """
    return Exp.apply(self)

def sigmoid(self) -> Scalar:
    """Return a new Scalar holding ``sigmoid(self)``.

    Passes ``self`` rather than ``self.data`` so the result stays attached
    to this variable's history for backprop.
    """
    return Sigmoid.apply(self)

def relu(self) -> Scalar:
    """Return a new Scalar holding ``relu(self)``.

    Passes ``self`` rather than ``self.data`` so the result stays attached
    to this variable's history for backprop.
    """
    return ReLU.apply(self)

# Variable elements for backprop

Expand Down
34 changes: 17 additions & 17 deletions minitorch/scalar_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Add(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float, b: float) -> float:
    """Addition forward pass: f(a, b) = a + b.

    Nothing is saved in ``ctx``; the derivative of addition is constant.
    """
    return operators.add(a, b)

@staticmethod
def backward(ctx: Context, d_output: float) -> Tuple[float, ...]:
Expand Down Expand Up @@ -103,8 +103,8 @@ class Mul(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float, b: float) -> float:
    """Multiplication forward pass: f(a, b) = a * b.

    NOTE(review): only ``b`` is saved, but d(a*b)/db = a — confirm that
    ``backward`` does not also need ``a`` stashed here.
    """
    ctx.save_for_backward(b)
    return operators.mul(a, b)

@staticmethod
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
Expand All @@ -117,8 +117,8 @@ class Inv(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float) -> float:
    """Reciprocal forward pass: f(a) = 1 / a.

    Saves ``a`` for the backward pass (d/da = -1 / a**2).
    """
    ctx.save_for_backward(a)
    return operators.inv(a)

@staticmethod
def backward(ctx: Context, d_output: float) -> float:
Expand All @@ -131,8 +131,8 @@ class Neg(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float) -> float:
    """Negation forward pass: f(a) = -a.

    NOTE(review): negation's derivative is the constant -1, so saving ``a``
    looks unnecessary — confirm whether ``backward`` actually reads it.
    """
    ctx.save_for_backward(a)
    return operators.neg(a)

@staticmethod
def backward(ctx: Context, d_output: float) -> float:
Expand All @@ -145,8 +145,8 @@ class Sigmoid(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float) -> float:
    """Sigmoid forward pass: f(a) = 1 / (1 + e^(-a)).

    Saves the *input* ``a``; backward can recompute sigmoid(a) from it.
    (Saving the forward output instead would avoid that recomputation.)
    """
    ctx.save_for_backward(a)
    return operators.sigmoid(a)

@staticmethod
def backward(ctx: Context, d_output: float) -> float:
Expand All @@ -159,8 +159,8 @@ class ReLU(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float) -> float:
    """ReLU forward pass: f(a) = max(0, a).

    Saves ``a`` so backward can gate the gradient on the sign of the input.
    """
    ctx.save_for_backward(a)
    return operators.relu(a)

@staticmethod
def backward(ctx: Context, d_output: float) -> float:
Expand All @@ -173,8 +173,8 @@ class Exp(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float) -> float:
    """Exponential forward pass: f(a) = e^a.

    Saves the *input* ``a``; backward can recompute exp(a) from it.
    (Saving the forward output instead would avoid that recomputation.)
    """
    ctx.save_for_backward(a)
    return operators.exp(a)

@staticmethod
def backward(ctx: Context, d_output: float) -> float:
Expand All @@ -187,8 +187,8 @@ class LT(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float, b: float) -> float:
    """Less-than forward pass: 1.0 if a < b else 0.0.

    NOTE(review): a comparison has zero derivative almost everywhere, so
    saving ``a`` and ``b`` looks unnecessary — confirm backward's needs.
    """
    ctx.save_for_backward(a, b)
    return operators.lt(a, b)

@staticmethod
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
Expand All @@ -201,8 +201,8 @@ class EQ(ScalarFunction):

@staticmethod
def forward(ctx: Context, a: float, b: float) -> float:
    """Equality forward pass: 1.0 if a == b else 0.0.

    NOTE(review): equality has zero derivative almost everywhere, so saving
    ``a`` and ``b`` looks unnecessary — confirm backward's needs.
    """
    ctx.save_for_backward(a, b)
    return operators.eq(a, b)

@staticmethod
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
Expand Down

0 comments on commit 5810e9a

Please sign in to comment.