Skip to content

Commit

Permalink
fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 committed Feb 13, 2024
1 parent 9a7e7f4 commit a2d5c08
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torchtrain/metrics_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import torch
import torch.nn as nn

def get_num_params(model: nn.Module, only_trainable: bool = False)-> int:
""" Get the total model params
args: only_trainable - only count trainable params
"""

Get the total model params
Args: only_trainable: whether to only count trainable params
"""
param_list = list(model.parameters())
if only_trainable:
param_list = [p for p in param_list if p.requires_grad]
Expand Down

0 comments on commit a2d5c08

Please sign in to comment.