Skip to content

imdb

Bases: text_dataloader

Source code in tinybig/data/text_dataloader_torchtext.py
class imdb(text_dataloader):
    """
    Dataloader for the IMDB movie-review sentiment dataset.

    The raw splits are produced by the ``IMDB`` datapipe. Presumably this is
    ``torchtext.datasets.IMDB``, given that this class lives in
    ``text_dataloader_torchtext.py``. Everything else about loading is
    inherited from ``text_dataloader``.
    """

    def __init__(self, name='imdb', train_batch_size=64, test_batch_size=64, max_seq_len: int = 512):
        """
        Configure the IMDB loader by forwarding every setting to the parent.

        Parameters
        ----------
        name : str
            Dataset name passed to ``text_dataloader``.
        train_batch_size : int
            Batch size passed to the parent for the training split.
        test_batch_size : int
            Batch size passed to the parent for the test split.
        max_seq_len : int
            Passed to the parent unchanged. NOTE(review): presumably a cap on
            tokens per review; confirm in ``text_dataloader``.
        """
        super().__init__(
            name=name,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
            max_seq_len=max_seq_len,
        )

    def load(self, *args, **kwargs):
        """
        Delegate to the parent ``load`` with ``xy_reversed`` forced to True.

        A caller-supplied ``xy_reversed`` is overridden. NOTE(review): the
        reason is presumably that the IMDB datapipe yields ``(label, text)``
        rather than ``(text, label)``; confirm against
        ``text_dataloader.load``.
        """
        # Build a merged copy; the key is appended last, or its value is
        # overwritten in place if the caller already passed it.
        forwarded = {**kwargs, 'xy_reversed': True}
        return super().load(*args, **forwarded)

    @staticmethod
    def load_datapipe(cache_dir='./data/', *args, **kwargs):
        """
        Build the raw IMDB datapipes for both splits.

        Parameters
        ----------
        cache_dir : str
            Directory handed to ``IMDB`` as ``root``.
        *args, **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            ``(train_datapipe, test_datapipe)``. The two datapipes are
            constructed in that order.
        """
        split_names = ("train", "test")
        return tuple(IMDB(root=cache_dir, split=split) for split in split_names)

    @staticmethod
    def get_class_number(*args, **kwargs):
        """Return the number of sentiment classes (two) in IMDB."""
        return 2

    @staticmethod
    def get_train_number(*args, **kwargs):
        """Return the hard-coded size of the IMDB training split."""
        return 25_000

    @staticmethod
    def get_test_number(*args, **kwargs):
        """Return the hard-coded size of the IMDB test split."""
        return 25_000

    @staticmethod
    def get_idx_to_label(*args, **kwargs):
        """
        Return a fresh mapping from raw dataset labels to class indices.

        The raw labels 1 and 2 are shifted down to 0 and 1. NOTE(review):
        presumably torchtext encodes 1 = negative and 2 = positive; verify
        against the installed torchtext version.
        """
        return {raw_label: raw_label - 1 for raw_label in (1, 2)}