Skip to content

Commit

Permalink
Implement set train and eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Mar 19, 2024
1 parent 35b1edd commit c1a6b24
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions minitorch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,24 @@ def modules(self) -> Sequence[Module]:

def train(self) -> None:
"Set the mode of this module and all descendent modules to `train`."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
self._set_mode(self, train=True)

def _set_mode(self, root: Module, train: bool = False) -> None:
if not root:
return

if train:
root.training = True
else:
root.training = False
modules = getattr(root, "_modules", None)
if modules is not None:
for _, v in modules.items():
self._set_mode(v, train)

def eval(self) -> None:
"Set the mode of this module and all descendent modules to `eval`."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
self._set_mode(self, train=False)

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
"""
Expand Down

0 comments on commit c1a6b24

Please sign in to comment.