-
Notifications
You must be signed in to change notification settings - Fork 42
/
test.py
81 lines (67 loc) · 3.36 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn.functional as F
import os
import json
from torchmeta.utils.data import BatchMetaDataLoader
from maml.datasets import get_benchmark_by_name
from maml.metalearners import ModelAgnosticMetaLearning
def main(args):
    """Evaluate a trained MAML model on the meta-test split and save the metrics.

    The run is driven by the JSON configuration written at training time
    (``args.config``). A few of its entries can be overridden from the command
    line. The metrics produced by ``metalearner.evaluate`` are written to
    ``results.json`` in the same directory as ``config['model_path']``.

    Args:
        args: Parsed command-line namespace. Read attributes: ``config``,
            ``folder``, ``num_steps``, ``num_batches``, ``use_cuda``,
            ``num_workers``, ``verbose``.
    """
    with open(args.config, 'r') as config_file:
        config = json.load(config_file)

    # (config key, CLI value, whether the CLI value was actually supplied).
    # A missing folder is None; step/batch counts use a non-positive sentinel
    # to mean "keep the value stored in the configuration file".
    cli_overrides = (
        ('folder', args.folder, args.folder is not None),
        ('num_steps', args.num_steps, args.num_steps > 0),
        ('num_batches', args.num_batches, args.num_batches > 0),
    )
    for key, value, supplied in cli_overrides:
        if supplied:
            config[key] = value

    # Only run on the GPU when it was requested AND one is present.
    use_gpu = args.use_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    benchmark = get_benchmark_by_name(
        config['dataset'],
        config['folder'],
        config['num_ways'],
        config['num_shots'],
        config['num_shots_test'],
        hidden_size=config['hidden_size'],
    )

    # map_location lets a checkpoint saved on GPU be restored on CPU.
    # NOTE(review): torch.load unpickles the file; only point model_path at
    # checkpoints you trust.
    with open(config['model_path'], 'rb') as weights_file:
        state_dict = torch.load(weights_file, map_location=device)
    benchmark.model.load_state_dict(state_dict)

    meta_test_dataloader = BatchMetaDataLoader(
        benchmark.meta_test_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    metalearner = ModelAgnosticMetaLearning(
        benchmark.model,
        first_order=config['first_order'],
        num_adaptation_steps=config['num_steps'],
        step_size=config['step_size'],
        loss_function=benchmark.loss_function,
        device=device,
    )

    results = metalearner.evaluate(
        meta_test_dataloader,
        max_batches=config['num_batches'],
        verbose=args.verbose,
        desc='Test',
    )

    # Store the metrics beside the checkpoint they were computed from.
    # Assumes ``results`` is JSON-serializable (whatever evaluate() returns).
    checkpoint_dir = os.path.dirname(config['model_path'])
    results_path = os.path.join(checkpoint_dir, 'results.json')
    with open(results_path, 'w') as results_file:
        json.dump(results, results_file)
if __name__ == '__main__':
    # Command-line entry point: parse the evaluation options and run main().
    import argparse

    parser = argparse.ArgumentParser('MAML')
    parser.add_argument(
        'config', type=str,
        help='Path to the configuration file returned by `train.py`.')
    # The data folder is a filesystem path. It was previously parsed with
    # type=int, which made argparse reject every real directory name.
    parser.add_argument(
        '--folder', type=str, default=None,
        help='Path to the folder the data is downloaded to. '
             '(default: path defined in configuration file).')

    # Optimization
    # A default of -1 is a sentinel: any non-positive value keeps the value
    # stored in the configuration file.
    parser.add_argument(
        '--num-steps', type=int, default=-1,
        help='Number of fast adaptation steps, ie. gradient descent updates '
             '(default: number of steps in configuration file).')
    parser.add_argument(
        '--num-batches', type=int, default=-1,
        help='Number of batch of tasks per epoch '
             '(default: number of batches in configuration file).')

    # Misc
    parser.add_argument(
        '--num-workers', type=int, default=1,
        help='Number of workers to use for data-loading (default: 1).')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--use-cuda', action='store_true')

    args = parser.parse_args()
    main(args)