Skip to content

graph_interdependence

Bases: interdependence

Source code in tinybig/interdependence/topological_interdependence.py
class graph_interdependence(interdependence):
    """
    Interdependence function whose matrix is derived from a graph structure.

    The graph is either supplied directly or built from node and link lists.
    Its matrix form (obtained through ``graph.to_matrix``) is post-processed
    and returned by ``calculate_A``. The node-id/index mappings produced
    alongside the matrix are kept on the instance and stay ``None`` until
    ``calculate_A`` has actually computed the matrix.
    """

    def __init__(
        self,
        b: int, m: int,
        interdependence_type: str = 'instance',
        name: str = 'graph_interdependence',
        graph: graph_structure = None,
        nodes: list = None, links: list = None, directed: bool = True,
        normalization: bool = False, normalization_mode: str = 'row',
        self_dependence: bool = False,
        require_data: bool = False, require_parameters: bool = False,
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initialize the graph based interdependence function.

        Parameters
        ----------
        b, m, interdependence_type, name, require_data, require_parameters, device:
            Forwarded unchanged to the base class constructor, together with
            any extra ``*args`` / ``**kwargs``.
        graph:
            A ready-made graph structure. When given, it takes precedence
            over ``nodes`` / ``links``.
        nodes, links:
            Node and link lists used to build a ``graph_structure`` when
            ``graph`` is None. Both must be provided together.
        directed:
            Passed to ``graph_structure`` when the graph is built from
            ``nodes`` and ``links``; not used when ``graph`` is given.
        normalization, normalization_mode, self_dependence:
            Stored on the instance and later passed to ``graph.to_matrix``
            by ``calculate_A``.

        Raises
        ------
        ValueError
            If ``graph`` is None and ``nodes`` or ``links`` is missing.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_data=require_data, require_parameters=require_parameters, device=device, *args, **kwargs)

        if graph is not None:
            self.graph = graph
        elif nodes is not None and links is not None:
            self.graph = graph_structure(nodes=nodes, links=links, directed=directed)
        else:
            # A graph can only be built when BOTH lists are present, so the
            # message must not suggest that a single one would be enough.
            raise ValueError('Either graph, or both nodes and links, must be provided')

        # Filled in by calculate_A once the graph has been converted to a matrix.
        self.node_id_index_map = None
        self.node_index_id_map = None

        self.normalization = normalization
        self.normalization_mode = normalization_mode
        self.self_dependence = self_dependence

    def get_node_index_id_map(self):
        """
        Return the ``node_index_id_map`` produced by the last matrix computation.

        Emits a warning and returns None when ``calculate_A`` has not
        populated the mapping yet.
        """
        if self.node_index_id_map is None:
            warnings.warn("The mapping has not been assigned yet, please call the calculate_A method first...")
        return self.node_index_id_map

    def get_node_id_index_map(self):
        """
        Return the ``node_id_index_map`` produced by the last matrix computation.

        Emits a warning and returns None when ``calculate_A`` has not
        populated the mapping yet.
        """
        if self.node_id_index_map is None:
            warnings.warn("The mapping has not been assigned yet, please call the calculate_A method first...")
        return self.node_id_index_map

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Compute (or fetch from cache) the interdependence matrix of the graph.

        Parameters
        ----------
        x, w:
            Accepted for interface compatibility; this implementation does
            not read them.
        device:
            Passed to ``graph.to_matrix`` and to ``post_process``.

        Returns
        -------
        The post-processed matrix form of the graph.

        Raises
        ------
        AssertionError
            If the resulting shape does not match ``(m, m')`` for attribute
            style interdependence types or ``(b, b')`` for instance style
            ones. Other interdependence types are not shape-checked.

        Notes
        -----
        When neither data nor parameters are required the result is stored
        in ``self.A`` and reused by later calls. The cached matrix is
        returned as is, regardless of the ``device`` argument of that call.
        """
        cacheable = not self.require_data and not self.require_parameters

        # self.A is presumably initialised by the base class -- TODO confirm.
        if cacheable and self.A is not None:
            return self.A

        adj, mappings = self.graph.to_matrix(self_dependence=self.self_dependence, normalization=self.normalization, normalization_mode=self.normalization_mode, device=device)

        self.node_id_index_map = mappings['node_id_index_map']
        self.node_index_id_map = mappings['node_index_id_map']

        A = self.post_process(x=adj, device=device)

        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert A.shape == (self.m, self.calculate_m_prime()), \
                f"attribute interdependence matrix has unexpected shape {A.shape}"
        elif self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            assert A.shape == (self.b, self.calculate_b_prime()), \
                f"instance interdependence matrix has unexpected shape {A.shape}"

        # Reaching this point with cacheable=True implies self.A was None.
        if cacheable:
            self.A = A
        return A