graph_dataloader

Bases: dataloader
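
The graph_dataloader class loads graph datasets: it downloads the raw node and link files described by its data_profile, assembles them into a graph structure together with node features and labels, and exposes the data either transductively (the whole graph served in one batch, with train/test node indices) or inductively (node features and labels split into separate train and test loaders).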

Source code in tinybig/data/graph_dataloader.py
class graph_dataloader(dataloader):

    def __init__(self, data_profile: dict = None, name: str = 'graph_data', train_batch_size: int = 64, test_batch_size: int = 64):
        super().__init__(name=name, train_batch_size=train_batch_size, test_batch_size=test_batch_size)

        self.data_profile = data_profile
        self.graph = None

    @staticmethod
    def download_data(data_profile: dict, cache_dir: str = None, file_name: str = None):
        if data_profile is None or 'url' not in data_profile:
            raise ValueError('data_profile must not be None and must contain a "url" key...')

        if cache_dir is None:
            cache_dir = './data/'

        if file_name is None:
            # no specific file requested: download every file listed in the profile
            for name in data_profile['url']:
                download_file_from_github(url_link=data_profile['url'][name], destination_path="{}/{}".format(cache_dir, name))
        else:
            # download only the requested file
            assert file_name in data_profile['url']
            download_file_from_github(url_link=data_profile['url'][file_name], destination_path="{}/{}".format(cache_dir, file_name))


    def load_raw(self, cache_dir: str, device: str = 'cpu', normalization: bool = True, normalization_mode: str = 'row'):
        # download the raw node and link files if they are not cached locally yet
        if not check_file_existence("{}/node".format(cache_dir)):
            self.download_data(data_profile=self.data_profile, cache_dir=cache_dir, file_name='node')
        if not check_file_existence("{}/link".format(cache_dir)):
            self.download_data(data_profile=self.data_profile, cache_dir=cache_dir, file_name='link')

        # each row of the node file contains: node id, feature values, label string
        idx_features_labels = np.genfromtxt("{}/node".format(cache_dir), dtype=np.dtype(str))
        X = torch.tensor(sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32).todense())
        y = dataloader.encode_str_labels(labels=idx_features_labels[:, -1], one_hot=False)

        if normalization:
            X = degree_based_normalize_matrix(mx=X, mode=normalization_mode)

        # build the graph structure from the node ids and the link file (edge list)
        nodes = np.array(idx_features_labels[:, 0], dtype=np.int32).tolist()
        links = np.genfromtxt("{}/link".format(cache_dir), dtype=np.int32).tolist()
        graph = graph_class(
            nodes=nodes, links=links, directed=True, device=device
        )
        return graph, X, y

    def save_graph(self, complete_path: str, graph: graph_class = None):
        graph = graph if graph is not None else self.graph
        if graph is None:
            raise ValueError('The graph structure has not been loaded yet...')
        if complete_path is None:
            raise ValueError('The cache complete_path has not been set yet...')
        return graph.save(complete_path=complete_path)

    def load_graph(self, complete_path: str):
        if complete_path is None:
            raise ValueError('The cache complete_path has not been set yet...')
        self.graph = graph_class.load(complete_path=complete_path)
        return self.graph

    def get_graph(self):
        return self.graph

    def get_adj(self, graph: graph_class = None):
        graph = graph if graph is not None else self.graph
        if graph is None:
            raise ValueError('The graph structure has not been loaded yet...')
        return graph.to_matrix(
            normalization=True,
            normalization_mode='row',
        )

    def load(self, mode: str = 'transductive', cache_dir: str = None, device: str = 'cpu',
             train_percentage: float = 0.5, random_state: int = 1234, shuffle: bool = False, *args, **kwargs):

        cache_dir = cache_dir if cache_dir is not None else "./data/{}".format(self.name)
        self.graph, X, y = self.load_raw(cache_dir=cache_dir, device=device)

        if mode == 'transductive':
            # transductive: the full dataset is served in a single batch; the split is defined by node indices
            warnings.warn("For the transductive setting, the train/test/val partition does not follow the provided parameters (e.g., train_percentage, batch size, etc.)...")
            train_idx, test_idx = self.get_train_test_idx(X=X, y=y)
            complete_dataset = dataset(X, y)
            complete_dataloader = DataLoader(dataset=complete_dataset, batch_size=len(X), shuffle=False)
            return {'train_idx': train_idx, 'test_idx': test_idx, 'train_loader': complete_dataloader, 'test_loader': complete_dataloader, 'graph_structure': self.graph}
        else:
            # inductive: split the node features and labels into disjoint train and test sets
            X_train, X_test, y_train, y_test = train_test_split(
                X, y,
                train_size=int(train_percentage * len(X)),
                random_state=random_state, shuffle=shuffle
            )
            train_dataset = dataset(X_train, y_train)
            test_dataset = dataset(X_test, y_test)
            if self.train_batch_size >= 1:
                train_loader = DataLoader(dataset=train_dataset, batch_size=self.train_batch_size, shuffle=True)
            else:
                train_loader = DataLoader(dataset=train_dataset, batch_size=len(X_train), shuffle=True)
            if self.test_batch_size >= 1:
                test_loader = DataLoader(dataset=test_dataset, batch_size=self.test_batch_size, shuffle=False)
            else:
                test_loader = DataLoader(dataset=test_dataset, batch_size=len(X_test), shuffle=False)
            return {'train_loader': train_loader, 'test_loader': test_loader, 'graph_structure': self.graph}

    @abstractmethod
    def get_train_test_idx(self, X: torch.Tensor = None, y: torch.Tensor = None, *args, **kwargs):
        # to be implemented by concrete subclasses: return the train and test node index splits
        pass
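
Since get_train_test_idx is abstract, graph_dataloader is meant to be used through a concrete subclass that defines how the node indices are split into training and test sets. Below is a minimal usage sketch of such a subclass together with a transductive load call; the toy_graph_dataloader name, the example URLs in data_profile, the half-and-half index split, and the import path (inferred from the source path above) are illustrative assumptions rather than part of the library.

# A minimal usage sketch, not part of the library source.
import torch
from tinybig.data.graph_dataloader import graph_dataloader  # assumed import path

class toy_graph_dataloader(graph_dataloader):
    def get_train_test_idx(self, X: torch.Tensor = None, y: torch.Tensor = None, *args, **kwargs):
        # hypothetical split: first half of the nodes for training, second half for testing
        n = X.shape[0]
        return torch.arange(0, n // 2), torch.arange(n // 2, n)

data_profile = {
    'url': {
        'node': 'https://example.com/toy/node',  # placeholder URL for the node file
        'link': 'https://example.com/toy/link',  # placeholder URL for the link file
    }
}

loader = toy_graph_dataloader(data_profile=data_profile, name='toy_graph_data')
data = loader.load(mode='transductive', cache_dir='./data/toy_graph_data', device='cpu')

adj = loader.get_adj()                    # row-normalized adjacency built from the loaded graph
X, y = next(iter(data['train_loader']))   # single full batch with all node features and labels
train_idx, test_idx = data['train_idx'], data['test_idx']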