Skip to content

citeseer

Bases: graph_dataloader

Source code in tinybig/data/graph_dataloader.py
class citeseer(graph_dataloader):
    """
    Graph dataloader for the Citeseer dataset.

    A thin wrapper around ``graph_dataloader`` that supplies
    ``CITESEER_DATA_PROFILE`` and a fixed, index-based train/test split.

    Attributes
    ----------
    TRAIN_INDEX_RANGE : tuple[int, int]
        Half-open ``[start, stop)`` range of node indices used for training.
    TEST_INDEX_RANGE : tuple[int, int]
        Half-open ``[start, stop)`` range of node indices used for testing.
    """

    # Split boundaries are class attributes so a subclass can override them
    # without re-implementing get_train_test_idx. Note the gap: node indices
    # 120-199 fall in neither the train nor the test range.
    TRAIN_INDEX_RANGE = (0, 120)
    TEST_INDEX_RANGE = (200, 1200)

    def __init__(self, name: str = 'citeseer', train_batch_size: int = 64, test_batch_size: int = 64, *args, **kwargs):
        """
        Initialize the Citeseer dataloader.

        Parameters
        ----------
        name : str
            Name of the dataloader, forwarded to ``graph_dataloader``.
        train_batch_size : int
            Batch size for the training loader, forwarded to the parent.
        test_batch_size : int
            Batch size for the testing loader, forwarded to the parent.
        *args, **kwargs
            Accepted for interface compatibility but NOT forwarded to the
            parent constructor; any extra options passed here are ignored.
        """
        super().__init__(data_profile=CITESEER_DATA_PROFILE, name=name, train_batch_size=train_batch_size, test_batch_size=test_batch_size)

    def get_train_test_idx(self, X: torch.Tensor = None, y: torch.Tensor = None, *args, **kwargs):
        """
        Return the fixed train and test node index tensors.

        Parameters
        ----------
        X, y : torch.Tensor, optional
            Ignored; the split is fixed and does not depend on the data.
        *args, **kwargs
            Ignored; accepted for interface compatibility.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(train_idx, test_idx)``, both 1-D int64 tensors. With the default
            class attributes these are ``0..119`` and ``200..1199``.
        """
        # torch.arange builds the int64 index tensor directly instead of going
        # through the legacy torch.LongTensor(range(...)) constructor.
        train_idx = torch.arange(*self.TRAIN_INDEX_RANGE, dtype=torch.long)
        test_idx = torch.arange(*self.TEST_INDEX_RANGE, dtype=torch.long)
        return train_idx, test_idx