class exponential_approx_multihop_chain_interdependence(chain_interdependence):
    """Chain interdependence whose matrix is the matrix exponential of the chain adjacency.

    The adjacency matrix comes from ``self.chain.to_matrix``. Its matrix exponential
    ``exp(adj) = sum_k adj^k / k!`` is taken as the interdependence matrix, so every
    hop count contributes, with weight ``1/k!``.
    """

    def __init__(self, name: str = 'exponential_approx_multihop_chain_interdependence', normalization: bool = False, normalization_mode: str = 'row', *args, **kwargs):
        """Forward all configuration unchanged to ``chain_interdependence.__init__``.

        Args:
            name: identifier for this interdependence instance.
            normalization: passed through to the parent; later forwarded to
                ``self.chain.to_matrix`` in ``calculate_A``.
            normalization_mode: passed through to the parent; later forwarded to
                ``self.chain.to_matrix`` in ``calculate_A``.
            *args, **kwargs: any further parent-constructor arguments.
        """
        # Positional *args bind positionally regardless of where they appear in
        # the call, so listing them first is equivalent to listing them last.
        super().__init__(
            *args,
            name=name,
            normalization=normalization,
            normalization_mode=normalization_mode,
            **kwargs,
        )

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """Build (or reuse) the interdependence matrix ``exp(adj)``.

        Args:
            x: unused by this implementation; accepted for interface compatibility.
            w: unused by this implementation; accepted for interface compatibility.
            device: device passed to ``self.chain.to_matrix`` and ``self.post_process``.

        Returns:
            The post-processed matrix exponential of the chain adjacency matrix.
            If neither data nor parameters are required, the first result is cached
            in ``self.A`` and returned directly on later calls.

        Side effects:
            Sets ``self.node_id_index_map`` and ``self.node_index_id_map`` from the
            mappings returned by ``self.chain.to_matrix``.
        """
        def _is_static() -> bool:
            # The matrix is cacheable only if it depends on neither inputs nor weights.
            return not (self.require_data or self.require_parameters)

        if _is_static() and self.A is not None:
            return self.A

        adj, mappings = self.chain.to_matrix(
            normalization=self.normalization,
            normalization_mode=self.normalization_mode,
            device=device,
        )
        self.node_id_index_map = mappings['node_id_index_map']
        self.node_index_id_map = mappings['node_index_id_map']

        # MPS tensors are exponentiated on CPU and moved back afterwards
        # (presumably because matrix_exp lacks an MPS kernel — TODO confirm).
        on_mps = adj.device.type == 'mps'
        exp_input = adj.to('cpu') if on_mps else adj
        A = torch.matrix_exp(exp_input)
        if on_mps:
            A = A.to('mps')

        A = self.post_process(x=A, device=device)

        # Sanity-check the resulting shape against the declared orientation.
        orientation = self.interdependence_type
        if orientation in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert A.shape == (self.m, self.calculate_m_prime())
        elif orientation in ['row', 'left', 'instance', 'instance_interdependence']:
            assert A.shape == (self.b, self.calculate_b_prime())

        if _is_static() and self.A is None:
            self.A = A
        return A