-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
23 lines (18 loc) · 819 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from model_scripts import EDSR,RCAN
def get_model(cfg):
    """Build the network selected by ``cfg.MODEL.NAME`` and move it to the device.

    Reads the following configuration fields:

    * ``cfg.SYSTEM.DEVICE``    -- target device; also used as ``map_location``
      when a checkpoint is loaded.
    * ``cfg.MODEL.NAME``       -- ``"EDSR"`` or ``"RCAN"``.
    * ``cfg.DATASET.SR_SCALE`` -- forwarded to the model constructor as
      ``scale`` and embedded in the checkpoint file name.
    * ``cfg.MODEL.PRETRAINED`` -- when truthy, weights are loaded from
      ``./pretrained/<NAME>_trained_x<SR_SCALE>.pt`` (relative to the current
      working directory).

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        NotImplementedError: if ``cfg.MODEL.NAME`` is not a supported model.
    """
    device = cfg.SYSTEM.DEVICE
    name = cfg.MODEL.NAME

    # Only construction differs per architecture; loading and device placement
    # are shared below so the two models cannot drift apart.
    if name == "EDSR":
        model = EDSR.EDSR(scale=cfg.DATASET.SR_SCALE)
    elif name == "RCAN":
        model = RCAN.RCAN(scale=cfg.DATASET.SR_SCALE)
    else:
        raise NotImplementedError(
            f"This model is not yet supported: {name!r}"
        )

    if cfg.MODEL.PRETRAINED:
        weights_path = f"./pretrained/{name}_trained_x{cfg.DATASET.SR_SCALE}.pt"
        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should be placed in ./pretrained. If the installed
        # torch version supports it, consider passing weights_only=True.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)

    return model.to(device)