Skip to content

multihop_chain_interdependence

Bases: chain_interdependence

Source code in tinybig/interdependence/topological_interdependence.py
class multihop_chain_interdependence(chain_interdependence):
    """
    Chain interdependence derived from multi-hop connectivity of the underlying chain.

    The interdependence matrix is obtained from the chain adjacency matrix via
    ``matrix_power(adj, h)``, or via ``accumulative_matrix_power(adj, h)`` when
    ``accumulative`` is set. Both helpers are defined outside this class;
    presumably they give the h-th power and an accumulation of powers up to h
    respectively (NOTE(review): confirm against their definitions).

    Parameters
    ----------
    h : int, default = 1
        Hop count passed to the matrix-power helper.
    accumulative : bool, default = False
        When true, use ``accumulative_matrix_power`` instead of ``matrix_power``.
    name : str, default = 'multihop_chain_interdependence'
        Name forwarded to the parent constructor.
    *args, **kwargs
        Forwarded unchanged to ``chain_interdependence.__init__``.
    """

    def __init__(self, h: int = 1, accumulative: bool = False, name: str = 'multihop_chain_interdependence', *args, **kwargs):
        # Positional extras bind before keyword arguments regardless of where
        # they appear in the call; listing *args first just makes that explicit.
        super().__init__(*args, name=name, **kwargs)
        self.h = h
        self.accumulative = accumulative

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Build (or return the cached) multi-hop interdependence matrix.

        Parameters
        ----------
        x : torch.Tensor, default = None
            Not read by this implementation.
        w : torch.nn.Parameter, default = None
            Not read by this implementation.
        device : str, default = 'cpu'
            Device passed to ``self.chain.to_matrix`` and ``self.post_process``.

        Returns
        -------
        The post-processed interdependence matrix. For column/attribute
        interdependence its shape is checked against
        ``(self.m, self.calculate_m_prime())``. For row/instance
        interdependence it is checked against
        ``(self.b, self.calculate_b_prime())``.

        Raises
        ------
        AssertionError
            If the resulting shape does not match the expected one (only when
            assertions are enabled).

        Side effects
        ------------
        Sets ``self.node_id_index_map`` and ``self.node_index_id_map`` from the
        chain's mappings. Caches the result in ``self.A`` when neither
        ``require_data`` nor ``require_parameters`` is set and no cached
        matrix exists yet.
        """
        # Reuse the cached matrix when it does not depend on data or parameters.
        if not (self.require_data or self.require_parameters) and self.A is not None:
            return self.A

        adjacency, mappings = self.chain.to_matrix(
            self_dependence=self.self_dependence,
            normalization=self.normalization,
            normalization_mode=self.normalization_mode,
            device=device,
        )
        self.node_id_index_map = mappings['node_id_index_map']
        self.node_index_id_map = mappings['node_index_id_map']

        if not self.accumulative:
            hop_matrix = matrix_power(adjacency, self.h)
        else:
            hop_matrix = accumulative_matrix_power(adjacency, self.h)
            # Only bi-directional chains have the accumulated matrix
            # re-normalized column-wise.
            if self.is_bi_directional():
                hop_matrix = degree_based_normalize_matrix(hop_matrix, mode='column')

        hop_matrix = self.post_process(x=hop_matrix, device=device)

        column_types = ('column', 'right', 'attribute', 'attribute_interdependence')
        row_types = ('row', 'left', 'instance', 'instance_interdependence')
        if self.interdependence_type in column_types:
            assert hop_matrix.shape == (self.m, self.calculate_m_prime())
        elif self.interdependence_type in row_types:
            assert hop_matrix.shape == (self.b, self.calculate_b_prime())

        # Populate the cache only once, and only for static matrices.
        if not (self.require_data or self.require_parameters) and self.A is None:
            self.A = hop_matrix

        return hop_matrix