class interdependence(Module, function):
    """
    Base class for interdependence functions.

    An interdependence function computes a matrix ``A`` through the abstract
    :meth:`calculate_A` and applies it to the input data in one of two
    directions, selected by ``interdependence_type``:

    * instance-wise (``'row'``, ``'left'``, ``'instance'``,
      ``'instance_interdependence'``): ``A`` has shape ``(b, b')`` and the
      output is ``A^T @ x``;
    * attribute-wise (``'column'``, ``'right'``, ``'attribute'``,
      ``'attribute_interdependence'``): ``A`` has shape ``(m, m')`` and the
      output is ``x @ A``.

    NOTE(review): ``@abstractmethod`` on :meth:`calculate_A` is only enforced
    if one of the base classes uses ``abc.ABCMeta`` as its metaclass -- confirm.
    Otherwise a subclass that forgets to override it returns ``None`` and then
    fails the assertion in :meth:`forward`.
    """

    # Accepted aliases for each interdependence direction. Every method that
    # branches on the direction uses these, so they cannot drift apart.
    _INSTANCE_TYPES = ('row', 'left', 'instance', 'instance_interdependence')
    _ATTRIBUTE_TYPES = ('column', 'right', 'attribute', 'attribute_interdependence')

    def __init__(
        self,
        b: int, m: int,
        name: str = 'base_interdependency',
        interdependence_type: str = 'attribute',
        require_data: bool = True,
        require_parameters: bool = False,
        preprocess_functions=None,
        postprocess_functions=None,
        preprocess_function_configs=None,
        postprocess_function_configs=None,
        device: str = 'cpu',
        *args, **kwargs
    ):
        """
        Initialize the interdependence function.

        Parameters
        ----------
        b : int
            Expected size of the input's first dimension; used by the
            instance-wise shape checks.
        m : int
            Expected size of the input's second dimension; used by the
            attribute-wise shape checks.
        name : str
            Name forwarded to ``function.__init__``.
        interdependence_type : str
            Direction alias; see the class docstring for accepted values.
        require_data : bool
            If True, :meth:`forward` asserts that ``x`` is given and 2-D.
        require_parameters : bool
            If True, :meth:`forward` asserts that ``w`` is given and 2-D.
        preprocess_functions, postprocess_functions
            Functions (or specs) passed to ``config.instantiation_functions``.
        preprocess_function_configs, postprocess_function_configs
            Configs passed to ``config.instantiation_functions`` alongside
            the functions above.
        device : str
            Device forwarded to ``function.__init__`` and to the function
            instantiation helper.
        *args, **kwargs
            Accepted and ignored (e.g. unknown keys from a config dict).

        Raises
        ------
        ValueError
            If ``interdependence_type`` is not an accepted alias.
        """
        # nn.Module must be initialized before any attribute assignment.
        Module.__init__(self)
        function.__init__(self, name=name, device=device)
        self.interdependence_type = interdependence_type
        self.b = b
        self.m = m
        self.require_data = require_data
        self.require_parameters = require_parameters
        self.preprocess_functions = config.instantiation_functions(preprocess_functions, preprocess_function_configs, device=self.device)
        self.postprocess_functions = config.instantiation_functions(postprocess_functions, postprocess_function_configs, device=self.device)
        self.A = None

    @property
    def interdependence_type(self):
        """The direction alias this interdependence is applied along."""
        return self._interdependence_type

    @interdependence_type.setter
    def interdependence_type(self, value):
        """Validate and store the direction alias.

        Raises
        ------
        ValueError
            If ``value`` is not one of the instance- or attribute-wise aliases.
        """
        allowed_values = [*self._INSTANCE_TYPES, *self._ATTRIBUTE_TYPES]
        if value not in allowed_values:
            raise ValueError(f"Invalid value for interdependence_type. Allowed values are: {allowed_values}")
        self._interdependence_type = value

    def check_A_shape_validity(self, A: torch.Tensor):
        """
        Check that ``A`` has the shape expected for this interdependence type.

        Instance-wise types require ``(b, b')`` and attribute-wise types
        require ``(m, m')``. Here ``b'`` and ``m'`` come from
        :meth:`calculate_b_prime` and :meth:`calculate_m_prime`.

        Raises
        ------
        ValueError
            If ``A`` is None or the interdependence type is unsupported.
        AssertionError
            If the required dimension is unset or the shape does not match.
        """
        if A is None:
            raise ValueError("A must be provided")
        assert self.interdependence_type is not None and isinstance(self.interdependence_type, str)
        if self.interdependence_type in self._INSTANCE_TYPES:
            assert self.b is not None
            assert A.shape == (self.b, self.calculate_b_prime(b=self.b))
        elif self.interdependence_type in self._ATTRIBUTE_TYPES:
            assert self.m is not None
            assert A.shape == (self.m, self.calculate_m_prime(m=self.m))
        else:
            raise ValueError(f"The interdependence type {self.interdependence_type} is not supported...")

    def get_name(self):
        """Return the function's name."""
        return self.name

    def get_A(self):
        """Return the stored ``A`` matrix, warning and returning None if unset."""
        if self.A is None:
            warnings.warn("The A matrix is None...")
            return None
        else:
            return self.A

    def get_b(self):
        """Return ``b``, the expected size of the input's first dimension."""
        return self.b

    def get_m(self):
        """Return ``m``, the expected size of the input's second dimension."""
        return self.m

    def pre_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """Apply the configured preprocessing functions to ``x``."""
        return function.func_x(x, self.preprocess_functions, device=device)

    def post_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """Apply the configured postprocessing functions to ``x``."""
        return function.func_x(x, self.postprocess_functions, device=device)

    def to_config(self):
        """
        Serialize this object into a ``{function_class, function_parameters}`` dict.

        The processing functions are exported as configs through
        ``function.functions_to_configs``, not as live objects.
        """
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        attributes = {attr: getattr(self, attr) for attr in self.__dict__}
        # Default of None: these may not live in the instance __dict__
        # (e.g. if they are registered as nn.Module children).
        attributes.pop('preprocess_functions', None)
        attributes.pop('postprocess_functions', None)
        # The property setter stores the value as '_interdependence_type'.
        # Export it under the constructor's parameter name; otherwise a
        # reconstruction via __init__(**params) would drop it into **kwargs.
        if '_interdependence_type' in attributes:
            attributes['interdependence_type'] = attributes.pop('_interdependence_type')
        if self.preprocess_functions is not None:
            attributes['preprocess_function_configs'] = function.functions_to_configs(self.preprocess_functions)
        if self.postprocess_functions is not None:
            attributes['postprocess_function_configs'] = function.functions_to_configs(self.postprocess_functions)
        return {
            "function_class": class_name,
            "function_parameters": attributes
        }

    def calculate_l(self):
        """Return the number of learnable parameters; the base class has none."""
        return 0

    def calculate_b_prime(self, b: int = None):
        """
        Return the output first-dimension size ``b'``.

        The base implementation returns ``b`` (defaulting to ``self.b``)
        unchanged. It warns if the interdependence is not instance-wise.
        """
        b = b if b is not None else self.b
        if self.interdependence_type not in self._INSTANCE_TYPES:
            warnings.warn("The interdependence_type is not about the instances, its b dimension will not be changed...")
        return b

    def calculate_m_prime(self, m: int = None):
        """
        Return the output second-dimension size ``m'``.

        The base implementation returns ``m`` (defaulting to ``self.m``)
        unchanged. It warns if the interdependence is not attribute-wise.
        """
        m = m if m is not None else self.m
        if self.interdependence_type not in self._ATTRIBUTE_TYPES:
            warnings.warn("The interdependence_type is not about the attributes, its m dimension will not be changed...")
        return m

    def forward(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, kappa_x: torch.Tensor = None, device: str = 'cpu', *args, **kwargs):
        """
        Compute ``A`` from ``x``/``w`` and apply it to the data.

        Parameters
        ----------
        x : torch.Tensor, optional
            Raw input used to compute ``A``. Required (and must be 2-D) when
            ``require_data`` is True.
        w : torch.nn.Parameter, optional
            Parameters used to compute ``A``. Required (and must be 2-D) when
            ``require_parameters`` is True.
        kappa_x : torch.Tensor, optional
            If given, ``A`` is applied to this tensor instead of ``x``.
        device : str
            Forwarded to :meth:`calculate_A`.

        Returns
        -------
        torch.Tensor
            ``A^T @ data`` for instance-wise types, ``data @ A`` for
            attribute-wise types. Sparse matmul is used if either operand
            is sparse.

        Raises
        ------
        ValueError
            If neither ``x`` nor ``kappa_x`` is given, or the type is invalid.
        AssertionError
            On a missing or non-2-D required input, or if ``A`` is missing
            or incompatible with the data.
        """
        if self.require_data:
            assert x is not None and x.ndim == 2
        if self.require_parameters:
            assert w is not None and w.ndim == 2
        # A is computed from the raw x; it is applied to kappa_x when supplied.
        data_x = kappa_x if kappa_x is not None else x
        if data_x is None:
            raise ValueError("Either x or kappa_x must be provided to apply the interdependence matrix.")

        if self.interdependence_type in self._INSTANCE_TYPES:
            # A shape: b * b'. calculate_A receives the transposed input, if
            # any; x may legitimately be None when require_data is False.
            x_t = x.transpose(0, 1) if x is not None else None
            A = self.calculate_A(x_t, w, device=device)
            assert A is not None and A.size(0) == data_x.size(0)
            if data_x.is_sparse or A.is_sparse:
                xi_x = torch.sparse.mm(A.t(), data_x)
            else:
                xi_x = torch.matmul(A.t(), data_x)
            return xi_x
        elif self.interdependence_type in self._ATTRIBUTE_TYPES:
            # A shape: m * m'
            A = self.calculate_A(x, w, device=device)
            assert A is not None and A.size(0) == data_x.size(1)
            if data_x.is_sparse or A.is_sparse:
                xi_x = torch.sparse.mm(data_x, A)
            else:
                xi_x = torch.matmul(data_x, A)
            return xi_x
        else:
            raise ValueError(f"Invalid interdependence type: {self.interdependence_type}")

    @abstractmethod
    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Compute the interdependence matrix ``A``; subclasses must override this.

        The result must have shape ``(b, b')`` for instance-wise types, or
        ``(m, m')`` for attribute-wise types. For instance-wise types,
        :meth:`forward` passes ``x`` already transposed.
        """
        pass