-
Notifications
You must be signed in to change notification settings - Fork 6
/
data.py
19 lines (18 loc) · 841 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torchvision.datasets import CIFAR100
from torchvision import transforms
import torch
import config as cf
# CIFAR-100 input pipeline.
#
# NOTE: importing this module has side effects. It constructs both datasets
# with download=True, so torchvision fetches CIFAR-100 into cf.root if it is
# not already present. It then builds the train/test DataLoaders.

# These settings were previously hard-coded. They can now be overridden from
# config. When config does not define them, the original values (8 workers,
# eval batch of 100) are used, so existing configs behave exactly as before.
_NUM_WORKERS = getattr(cf, 'num_workers', 8)
_TEST_BATCH_SIZE = getattr(cf, 'test_batch_size', 100)

# Training-time augmentation:
# - random 32x32 crop from the image padded by 4 pixels on each side;
# - random horizontal flip;
# - conversion to a tensor;
# - per-channel normalization.
# The normalization statistics come from config, keyed by dataset name
# (presumably per-channel mean/std; confirm against config.py).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean['cifar100'], cf.std['cifar100']),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalization used for training so train/test inputs share one scale.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cf.mean['cifar100'], cf.std['cifar100']),
])

trainset = CIFAR100(root=cf.root, train=True, download=True, transform=transform_train)
testset = CIFAR100(root=cf.root, train=False, download=True, transform=transform_test)

# Training batches are reshuffled every epoch (shuffle=True).
# Evaluation keeps a fixed order so results are comparable across runs.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=cf.batch_size,
    shuffle=True,
    num_workers=_NUM_WORKERS,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=_TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=_NUM_WORKERS,
)