Skip to content

sst2

Bases: text_dataloader

Source code in tinybig/data/text_dataloader_torchtext.py
class sst2(text_dataloader):
    """Data loader for the binary SST-2 sentiment dataset.

    A thin specialisation of ``text_dataloader``. It fixes the default
    dataset name, supplies the (train, test) datapipes, and reports
    hard-coded dataset statistics (class count, split sizes, label map).
    """

    def __init__(self, name='sst2', train_batch_size=64, test_batch_size=64, max_seq_len: int = 32):
        """Forward all settings to ``text_dataloader.__init__`` by keyword.

        Args:
            name: Dataset identifier passed to the parent; defaults to ``'sst2'``.
            train_batch_size: Passed through to the parent unchanged.
            test_batch_size: Passed through to the parent unchanged.
            max_seq_len: Passed through to the parent unchanged.
        """
        settings = dict(
            name=name,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
            max_seq_len=max_seq_len,
        )
        super().__init__(**settings)

    @staticmethod
    def load_datapipe(cache_dir='./data/', *args, **kwargs):
        """Create the training and evaluation datapipes rooted at ``cache_dir``.

        The "dev" split serves as the test set. Any extra positional or
        keyword arguments are accepted and ignored.

        Returns:
            tuple: ``(train_datapipe, test_datapipe)`` built from the
            "train" and "dev" splits, in that order.
        """
        # NOTE(review): ``SST2`` is presumably ``torchtext.datasets.SST2``
        # imported elsewhere in this module. "dev" likely stands in for
        # "test" because the public SST-2 test split ships without labels.
        # Confirm both.
        return tuple(SST2(root=cache_dir, split=part) for part in ("train", "dev"))

    @staticmethod
    def get_class_number(*args, **kwargs):
        """Return the number of sentiment classes (binary task). Arguments are ignored."""
        num_classes = 2
        return num_classes

    @staticmethod
    def get_train_number(*args, **kwargs):
        """Return the hard-coded size of the "train" split. Arguments are ignored."""
        train_size = 67_349
        return train_size

    @staticmethod
    def get_test_number(*args, **kwargs):
        """Return the hard-coded size of the "dev" split used as the test set. Arguments are ignored."""
        test_size = 872
        return test_size

    @staticmethod
    def get_idx_to_label(*args, **kwargs):
        """Return the index-to-label map, which is the identity over ``{0, 1}``.

        Arguments are ignored.
        """
        return {label: label for label in range(2)}