-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
128 lines (101 loc) · 5.71 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Contains utilities to load the MNIST, FashionMNIST and CIFAR10 datasets into dataloaders
"""
import os
import random
import torch
from torch.utils.data import Subset
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision.transforms import transforms
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from PIL import Image
def get_mnist_dataloaders(root='./datasets/', batch_size=64, digits_to_include: list = None, size=None):
    '''
    Loads the MNIST train and test set into a dataloader
    :param root: dir to save dataset
    :param batch_size: size of batch
    :param digits_to_include: array of labels to be included into the dataset. Included labels are remapped to
        their position in this list and every other sample is dropped. Default=None (all labels are included)
    :param size: if given, only the first ``size`` samples of each split (after label filtering) are kept.
        A value larger than a split is clamped to that split's length. Default=None (keep everything)
    :return: DataLoader: train_dataloader (shuffled), DataLoader: test_dataloader (not shuffled).
    '''
    transform = transforms.Compose([transforms.ToTensor()])
    if digits_to_include is None:
        mnist_train_dataset = MNIST(root=root, download=True, train=True, transform=transform)
        mnist_test_dataset = MNIST(root=root, download=True, train=False, transform=transform)
    else:
        # Kept labels become their index in digits_to_include; -1 marks the samples that get_indices leaves out.
        label_transform = lambda x: digits_to_include.index(x) if x in digits_to_include else -1
        mnist_train_dataset = MNIST(root=root, download=True, train=True, transform=transform,
                                    target_transform=label_transform)
        mnist_test_dataset = MNIST(root=root, download=True, train=False, transform=transform,
                                   target_transform=label_transform)
        train_indices = get_indices(mnist_train_dataset)
        test_indices = get_indices(mnist_test_dataset)
        mnist_train_dataset = Subset(mnist_train_dataset, train_indices)
        mnist_test_dataset = Subset(mnist_test_dataset, test_indices)
    if size is not None:
        # Clamp per split: a split may hold fewer than `size` samples (e.g. after label filtering), and an
        # unclamped range(size) would give Subset out-of-range indices that only fail once the loader is iterated.
        train_size = min(size, len(mnist_train_dataset))
        test_size = min(size, len(mnist_test_dataset))
        mnist_train_dataset = Subset(mnist_train_dataset, range(train_size))
        mnist_test_dataset = Subset(mnist_test_dataset, range(test_size))
    train_dataloader = DataLoader(mnist_train_dataset, shuffle=True, batch_size=batch_size)
    test_dataloader = DataLoader(mnist_test_dataset, shuffle=False, batch_size=batch_size)
    return train_dataloader, test_dataloader
def get_fmnist_dataloaders(root='./datasets/', batch_size=64, digits_to_include: list = None):
    '''
    Loads the FashionMNIST train and test set into a dataloader
    :param root: dir to save dataset
    :param batch_size: size of batch
    :param digits_to_include: array of labels to be included into the dataset. Included labels are remapped to
        their position in this list and every other sample is dropped. Default=None (all labels are included)
    :return: DataLoader: train_dataloader (shuffled), DataLoader: test_dataloader (not shuffled).
    '''
    filtering = digits_to_include is not None

    def relabel(label):
        # Position inside digits_to_include for kept labels, -1 for everything that has to be dropped.
        if label in digits_to_include:
            return digits_to_include.index(label)
        return -1

    # Arguments shared by both splits; the label remapping is only passed when filtering is requested.
    shared_kwargs = dict(root=root, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    if filtering:
        shared_kwargs['target_transform'] = relabel

    train_split = FashionMNIST(train=True, **shared_kwargs)
    test_split = FashionMNIST(train=False, **shared_kwargs)

    if filtering:
        # Keep only the samples whose remapped label is not the -1 marker.
        kept_train = get_indices(train_split)
        kept_test = get_indices(test_split)
        train_split = Subset(train_split, kept_train)
        test_split = Subset(test_split, kept_test)

    return (
        DataLoader(train_split, shuffle=True, batch_size=batch_size),
        DataLoader(test_split, shuffle=False, batch_size=batch_size),
    )
def get_cifar10_dataloaders(root='./datasets/', batch_size=64, digits_to_include: list = None):
    '''
    Loads the CIFAR10 train and test set into a dataloader
    :param root: dir to save dataset
    :param batch_size: size of batch
    :param digits_to_include: array of labels to be included into the dataset. Included labels are remapped to
        their position in this list and every other sample is dropped. Default=None (all labels are included)
    :return: DataLoader: train_dataloader (shuffled), DataLoader: test_dataloader (not shuffled).
    '''
    to_tensor = transforms.Compose([transforms.ToTensor()])

    if digits_to_include is None:
        # No filtering requested: use both splits as they are.
        train_split = CIFAR10(root=root, download=True, train=True, transform=to_tensor)
        test_split = CIFAR10(root=root, download=True, train=False, transform=to_tensor)
    else:
        def relabel(label):
            # Position inside digits_to_include for kept labels, -1 for everything that has to be dropped.
            if label not in digits_to_include:
                return -1
            return digits_to_include.index(label)

        full_train = CIFAR10(root=root, download=True, train=True, transform=to_tensor,
                             target_transform=relabel)
        full_test = CIFAR10(root=root, download=True, train=False, transform=to_tensor,
                            target_transform=relabel)
        # Restrict each split to the samples whose remapped label is not the -1 marker.
        kept_train = get_indices(full_train)
        kept_test = get_indices(full_test)
        train_split = Subset(full_train, kept_train)
        test_split = Subset(full_test, kept_test)

    train_loader = DataLoader(train_split, shuffle=True, batch_size=batch_size)
    test_loader = DataLoader(test_split, shuffle=False, batch_size=batch_size)
    return train_loader, test_loader
def get_indices(dataset):
    '''
    Returns indices of datapoints whose label is not the -1 marker.
    :param dataset: iterable of (sample, label) pairs to get indices from
    :return: list: indices, in iteration order, of the entries with label != -1
    '''
    return [position for position, (_, label) in enumerate(dataset) if label != -1]