Skip to content

Commit

Permalink
Implement parameters/named_parameters traversal
Browse files Browse the repository at this point in the history
  * Use recursive dfs
  * Omit initial . in the path (to match with desired test results)
  * Access protected members through getattr
  • Loading branch information
dantp-ai committed Mar 19, 2024
1 parent c1a6b24 commit bdc90c5
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions minitorch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,33 @@ def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
Returns:
The name and `Parameter` of each ancestor parameter.
"""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
return self._traverse_tree(self, named=True)

def _traverse_tree(
self, root: Module, named: bool = False, path: str = ""
) -> Sequence[Tuple[str, Parameter] | Parameter]:
result = []
if not root:
return result

parameters = getattr(root, "_parameters", None)
if parameters is not None:
for k, p in parameters.items():
if named:
result.append((f"{path}.{k}" if path else f"{k}", p))
else:
result.append(p)
modules = getattr(root, "_modules", None)
if modules is not None:
for name, v in modules.items():
path_sofar = f"{path}.{name}" if path else name
result.extend(self._traverse_tree(v, named=named, path=path_sofar))

return result

def parameters(self) -> Sequence[Parameter]:
"Enumerate over all the parameters of this module and its descendents."
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
return self._traverse_tree(self)

def add_parameter(self, k: str, v: Any) -> Parameter:
"""
Expand Down

0 comments on commit bdc90c5

Please sign in to comment.