Skip to content

Commit

Permalink
add support for dumping weights
Browse files Browse the repository at this point in the history
Summary: As titled

Reviewed By: sf-wind

Differential Revision: D49113862

fbshipit-source-id: fce2410cab83e04e79f9bcb66c7959faff5aa090
  • Loading branch information
Yanghan Wang authored and facebook-github-bot committed Sep 9, 2023
1 parent 980a10e commit 5e3fdb5
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mobile_cv/arch/utils/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import TypeVar

import torch
from mobile_cv.common.misc import iter_utils as iu

T = TypeVar("T")


def move_to_device(data, device: str):
def move_to_device(data: T, device: str) -> T:
"""Move data to the given device, data could be a nested dict/list"""
diter = iu.recursive_iterate(data, iter_types=torch.Tensor)
for cur in diter:
Expand All @@ -14,7 +18,7 @@ def move_to_device(data, device: str):
return diter.value


def get_cpu_copy(data):
def get_cpu_copy(data: T) -> T:
"""Detach and copy data to cpu, data could be a nested dict/list"""
diter = iu.recursive_iterate(data, iter_types=torch.Tensor)
for cur in diter:
Expand All @@ -26,12 +30,12 @@ def get_cpu_copy(data):
class GPUWrapper(torch.nn.Module):
"""A simple wrapper to move the module to run on GPU"""

def __init__(self, module: torch.nn.Module):
def __init__(self, module: torch.nn.Module) -> None:
super().__init__()
self.module = module.cuda()
self.training = module.training

def forward(self, data: torch.Tensor):
def forward(self, data: torch.Tensor) -> torch.Tensor:
data = data.cuda()
ret = self.module(data)
ret = ret.cpu()
Expand Down

0 comments on commit 5e3fdb5

Please sign in to comment.