class imagenet(vision_dataloader):
    """ImageNet dataloader yielding flattened, normalized image tensors.

    Wraps torchvision's ``ImageNet`` dataset for the ``'train'`` and ``'val'``
    splits and exposes each split as a ``DataLoader``.
    """

    def __init__(self, name='imagenet', train_batch_size: int = 64, test_batch_size: int = 64):
        """Forward the dataset name and batch sizes to ``vision_dataloader``.

        Args:
            name: identifier passed through to the base class.
            train_batch_size: batch size used for the training loader.
            test_batch_size: batch size used for the evaluation ('val' split) loader.
        """
        super().__init__(name=name, train_batch_size=train_batch_size, test_batch_size=test_batch_size)

    def load(self, cache_dir='./data/', *args, **kwargs):
        """Build the training and evaluation DataLoaders.

        Args:
            cache_dir: root directory handed to torchvision's ``ImageNet``.
                Per the torchvision docs, that dataset does not download data,
                so the ImageNet archives are expected to already be present here.
            *args, **kwargs: accepted for interface compatibility; ignored.

        Returns:
            dict with keys ``'train_loader'`` (shuffled, ``split='train'``) and
            ``'test_loader'`` (not shuffled, ``split='val'``).
        """
        # Per-channel normalization applied after ToTensor; the same constants are
        # used for both splits so train and test inputs share one scale.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Training pipeline: resize, then a random-resized 224x224 crop as augmentation.
        # NOTE: RandomHorizontalFlip is not applied here (it was disabled previously);
        # add transforms.RandomHorizontalFlip() to this list if flip augmentation is wanted.
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            normalize,
            # Collapse the (C, 224, 224) tensor into a single 1-D vector.
            torch.flatten,
        ])

        # Evaluation pipeline: a deterministic center crop instead of a random one,
        # so repeated evaluations of the same image see identical pixels and test
        # metrics are reproducible.
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
            torch.flatten,
        ])

        train_loader = DataLoader(
            ImageNet(root=cache_dir, split='train', transform=train_transform),
            batch_size=self.train_batch_size, shuffle=True)
        test_loader = DataLoader(
            ImageNet(root=cache_dir, split='val', transform=test_transform),
            batch_size=self.test_batch_size, shuffle=False)
        return {'train_loader': train_loader, 'test_loader': test_loader}