Skip to content

agnews

Bases: text_dataloader

Source code in tinybig/data/text_dataloader_torchtext.py
class agnews(text_dataloader):
    """
    Data loader for the AG News topic-classification corpus.

    Specializes ``text_dataloader`` by supplying:

    - the AG News train/test datapipes (built with ``AG_NEWS``), and
    - the fixed dataset statistics: 4 classes, 120000 training examples and
      7600 test examples,
    - a mapping from the raw labels 1-4 onto the 0-based indices 0-3.
    """

    def __init__(self, name='ag_news', train_batch_size=64, test_batch_size=64):
        """
        Configure the loader.

        Every argument is forwarded unchanged, by keyword, to
        ``text_dataloader.__init__``.

        Parameters
        ----------
        name: str, default = 'ag_news'
            Dataset name passed to the base class.
        train_batch_size: int, default = 64
            Passed to the base class as ``train_batch_size``.
        test_batch_size: int, default = 64
            Passed to the base class as ``test_batch_size``.
        """
        super().__init__(
            name=name,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
        )

    def load(self, *args, **kwargs):
        """
        Load the dataset via ``text_dataloader.load`` with ``xy_reversed=True``.

        All positional and keyword arguments are forwarded to the base class.
        A caller-supplied ``xy_reversed`` is always overridden with ``True``.

        NOTE(review): the forced flag presumably reflects AG_NEWS samples being
        ordered (label, text) rather than (text, label). Confirm this against
        the base class.

        Returns
        -------
        Whatever ``text_dataloader.load`` returns.
        """
        kwargs.update(xy_reversed=True)
        return super().load(*args, **kwargs)

    @staticmethod
    def load_datapipe(cache_dir='./data/', *args, **kwargs):
        """
        Build the AG News datapipes rooted at ``cache_dir``.

        Parameters
        ----------
        cache_dir: str, default = './data/'
            Passed to ``AG_NEWS`` as its ``root`` directory.
        *args, **kwargs:
            Accepted for interface compatibility and ignored.

        Returns
        -------
        tuple
            ``(train_datapipe, test_datapipe)``, in that order.
        """
        # The train split is constructed first, then the test split, which
        # matches the order of the returned tuple.
        pipes = [AG_NEWS(root=cache_dir, split=split) for split in ("train", "test")]
        return tuple(pipes)

    @staticmethod
    def get_class_number(*args, **kwargs):
        """Return the number of AG News classes (4). Arguments are ignored."""
        return 4

    @staticmethod
    def get_train_number(*args, **kwargs):
        """Return the number of training examples (120000). Arguments are ignored."""
        return 120_000

    @staticmethod
    def get_test_number(*args, **kwargs):
        """Return the number of test examples (7600). Arguments are ignored."""
        return 7_600

    @staticmethod
    def get_idx_to_label(*args, **kwargs):
        """
        Return the mapping from raw AG News labels to 0-based class indices.

        Returns
        -------
        dict
            ``{1: 0, 2: 1, 3: 2, 4: 3}``. Arguments are ignored.
        """
        # Raw labels run 1..4; shift each down by one so indices start at 0.
        return {raw_label: raw_label - 1 for raw_label in range(1, 5)}