Skip to content

imagenet

Bases: vision_dataloader

Source code in tinybig/data/vision_dataloader.py
class imagenet(vision_dataloader):
    """
    Data loader for the ImageNet image classification dataset.

    Every image is resized so its shorter side is 256, cropped to 224x224,
    converted to a tensor, normalized with fixed per-channel mean/std values,
    and finally flattened into a 1-D tensor with ``torch.flatten``.

    Parameters
    ----------
    name : str, default 'imagenet'
        Name of the data loader; forwarded to ``vision_dataloader``.
    train_batch_size : int, default 64
        Batch size used by the training DataLoader.
    test_batch_size : int, default 64
        Batch size used by the test DataLoader (built on the ``'val'`` split).
    """

    def __init__(self, name='imagenet', train_batch_size: int = 64, test_batch_size: int = 64):
        # NOTE(review): the base class presumably stores these as
        # self.train_batch_size / self.test_batch_size, which load() reads.
        super().__init__(name=name, train_batch_size=train_batch_size, test_batch_size=test_batch_size)

    def load(self, cache_dir='./data/', *args, **kwargs):
        """
        Build the ImageNet train and test DataLoaders.

        Parameters
        ----------
        cache_dir : str, default './data/'
            Root directory passed to ``ImageNet(root=...)`` for both splits.
            NOTE(review): the dataset files are presumably expected to already
            exist under this directory; confirm against torchvision's ImageNet.
        *args, **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        dict
            ``{'train_loader': ..., 'test_loader': ...}``. The train loader
            reads the ``'train'`` split with shuffling. The test loader reads
            the ``'val'`` split without shuffling.
        """
        # Per-channel normalization shared by the train and test pipelines.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Training pipeline: the random resized crop acts as data augmentation.
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            normalize,
            torch.flatten,
        ])

        # Evaluation pipeline must be deterministic: a fixed center crop
        # replaces the random crop, so every evaluation sees the same inputs.
        # The output size (3x224x224, then flattened) matches training.
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
            torch.flatten,
        ])

        train_loader = DataLoader(
            ImageNet(root=cache_dir, split='train', transform=train_transform),
            batch_size=self.train_batch_size, shuffle=True)

        test_loader = DataLoader(
            ImageNet(root=cache_dir, split='val', transform=test_transform),
            batch_size=self.test_batch_size, shuffle=False)

        return {'train_loader': train_loader, 'test_loader': test_loader}