class sst2(text_dataloader):
    """Text dataloader wired to the ``SST2`` datapipes.

    Exposes fixed metadata (class count, split sizes, label mapping) through
    static methods and delegates batching/sequence settings to
    ``text_dataloader``.
    """

    def __init__(self, name='sst2', train_batch_size=64, test_batch_size=64, max_seq_len: int = 32):
        """Forward the name, batch sizes and max sequence length to the base class."""
        super().__init__(
            name=name,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
            max_seq_len=max_seq_len,
        )

    @staticmethod
    def load_datapipe(cache_dir='./data/', *args, **kwargs):
        """Build the (train, test) datapipes rooted at ``cache_dir``.

        The "dev" split is what gets returned as the test datapipe.
        Extra positional/keyword arguments are accepted and ignored.
        """
        # Order matters: "train" is constructed first, then "dev".
        return tuple(SST2(root=cache_dir, split=part) for part in ("train", "dev"))

    @staticmethod
    def get_class_number(*args, **kwargs):
        """Return the number of label classes (binary)."""
        return 2

    @staticmethod
    def get_train_number(*args, **kwargs):
        """Return the hard-coded number of training examples."""
        return 67_349

    @staticmethod
    def get_test_number(*args, **kwargs):
        """Return the hard-coded number of test ("dev" split) examples."""
        return 872

    @staticmethod
    def get_idx_to_label(*args, **kwargs):
        """Return a fresh identity mapping from class index to label ({0: 0, 1: 1})."""
        return {idx: idx for idx in range(2)}