class agnews(text_dataloader):
    """Data loader for the AG News text-classification dataset.

    Thin specialisation of ``text_dataloader``. The split sizes and class
    count are hard-coded constants here rather than computed from the data.
    """

    def __init__(self, name='ag_news', train_batch_size=64, test_batch_size=64, max_seq_len: int = 64):
        # Everything is delegated to the base class; this subclass only
        # changes the default dataset name.
        super().__init__(
            name=name,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
            max_seq_len=max_seq_len,
        )

    def load(self, *args, **kwargs):
        """Load the dataset through the base class with ``xy_reversed=True``.

        Any ``xy_reversed`` value supplied by the caller is overridden.
        NOTE(review): presumably the raw AG_NEWS samples are ordered
        (label, text) rather than (text, label) — confirm against the
        base-class ``load`` implementation.
        """
        return super().load(*args, **{**kwargs, 'xy_reversed': True})

    @staticmethod
    def load_datapipe(cache_dir='./data/', *args, **kwargs):
        """Build the AG_NEWS train and test datapipes rooted at *cache_dir*.

        Extra positional/keyword arguments are accepted but ignored.

        Returns:
            A ``(train_datapipe, test_datapipe)`` tuple.
        """
        train_pipe, test_pipe = (
            AG_NEWS(root=cache_dir, split=split_name)
            for split_name in ("train", "test")
        )
        return train_pipe, test_pipe

    @staticmethod
    def get_class_number(*args, **kwargs):
        """Return the number of target classes (fixed at 4)."""
        return 4

    @staticmethod
    def get_train_number(*args, **kwargs):
        """Return the hard-coded size of the training split."""
        return 120_000

    @staticmethod
    def get_test_number(*args, **kwargs):
        """Return the hard-coded size of the test split."""
        return 7_600

    @staticmethod
    def get_idx_to_label(*args, **kwargs):
        """Return a mapping from raw labels 1..4 to zero-based indices 0..3."""
        return {raw_label: raw_label - 1 for raw_label in range(1, 5)}