class chain_interdependence(interdependence):
    """
    Interdependence whose matrix ``A`` is derived from a chain structure.

    The chain is either supplied directly or built from ``chain_length`` and
    ``bi_directional``. ``calculate_A`` does the following:

    - converts the chain to a matrix via ``chain.to_matrix``;
    - passes it through ``post_process``;
    - checks its shape against the configured interdependence type;
    - caches it in ``self.A`` when neither data nor parameters are required.
    """

    def __init__(
        self,
        b: int, m: int,
        interdependence_type: str = 'instance',
        name: str = 'chain_interdependence',
        chain: chain_structure = None,
        chain_length: int = None, bi_directional: bool = False,
        normalization: bool = False, normalization_mode: str = 'row',
        self_dependence: bool = True, self_scaling: float = 1.0,
        require_data: bool = False, require_parameters: bool = False,
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initialize the chain interdependence.

        Parameters
        ----------
        b, m : int
            Forwarded to the base ``interdependence`` initializer. The base
            class presumably stores them as ``self.b`` / ``self.m``, which
            ``calculate_A`` reads for its shape check (confirm in base class).
        interdependence_type : str
            Selects the shape check in ``calculate_A``:
            ``'row'/'left'/'instance'/'instance_interdependence'`` expect
            shape ``(b, b')``, ``'column'/'right'/'attribute'/'attribute_interdependence'``
            expect shape ``(m, m')``, and any other value skips the check.
        name : str
            Forwarded to the base initializer.
        chain : chain_structure, optional
            Pre-built chain. When given, it takes precedence over ``chain_length``
            and ``bi_directional``.
        chain_length : int, optional
            Used only when ``chain`` is None, to build
            ``chain_structure(length=chain_length, bi_directional=bi_directional)``.
        bi_directional : bool
            Used only when the chain is built from ``chain_length``.
        normalization, normalization_mode, self_dependence, self_scaling
            Stored on the instance and forwarded to ``chain.to_matrix`` in
            ``calculate_A``.
        require_data, require_parameters : bool
            Forwarded to the base initializer. When both are False,
            ``calculate_A`` caches its result in ``self.A``.
        device : str
            Forwarded to the base initializer.

        Raises
        ------
        ValueError
            If neither ``chain`` nor ``chain_length`` is provided.
        """
        # *args goes first: the binding is identical to placing it after the
        # keywords, but it avoids star-arg unpacking after keyword arguments.
        super().__init__(
            *args,
            b=b, m=m, name=name,
            interdependence_type=interdependence_type,
            require_data=require_data,
            require_parameters=require_parameters,
            device=device,
            **kwargs
        )

        if chain is not None:
            self.chain = chain
        elif chain_length is not None:
            self.chain = chain_structure(length=chain_length, bi_directional=bi_directional)
        else:
            raise ValueError('Either chain structure or chain length must be provided...')

        # Populated by calculate_A from the mappings returned by chain.to_matrix.
        self.node_id_index_map = None
        self.node_index_id_map = None

        # Matrix-construction options, forwarded to chain.to_matrix.
        self.normalization = normalization
        self.normalization_mode = normalization_mode
        self.self_dependence = self_dependence
        self.self_scaling = self_scaling

    def is_bi_directional(self):
        """Return True when the underlying chain is not directed."""
        return not self.chain.is_directed()

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Build the interdependence matrix from the chain, or return the cached one.

        ``x`` and ``w`` are not read by this implementation. They are
        presumably accepted to match the base-class ``calculate_A`` signature.

        Parameters
        ----------
        device : str
            Forwarded to ``chain.to_matrix`` and ``post_process``.

        Returns
        -------
        The post-processed matrix. When neither data nor parameters are
        required, the first result is stored in ``self.A`` and returned
        unchanged by later calls.

        Side effects
        ------------
        Sets ``self.node_id_index_map`` and ``self.node_index_id_map`` from
        the mappings returned by ``chain.to_matrix``.
        """
        # A depends only on the fixed chain when no data/parameters are needed,
        # so it can be computed once and reused.
        cacheable = not self.require_data and not self.require_parameters
        if cacheable and self.A is not None:
            return self.A

        adj, mappings = self.chain.to_matrix(
            self_dependence=self.self_dependence,
            self_scaling=self.self_scaling,
            normalization=self.normalization,
            normalization_mode=self.normalization_mode,
            device=device
        )
        self.node_id_index_map = mappings['node_id_index_map']
        self.node_index_id_map = mappings['node_index_id_map']

        A = self.post_process(x=adj, device=device)

        # Shape sanity checks. They are kept as asserts, so calculate_*_prime()
        # is only evaluated when assertions are enabled, as before.
        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert A.shape == (self.m, self.calculate_m_prime()), (
                f"attribute interdependence matrix has unexpected shape {tuple(A.shape)}; "
                f"expected (m, m') with m={self.m}"
            )
        elif self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            assert A.shape == (self.b, self.calculate_b_prime()), (
                f"instance interdependence matrix has unexpected shape {tuple(A.shape)}; "
                f"expected (b, b') with b={self.b}"
            )

        if cacheable and self.A is None:
            self.A = A
        return A